micro_sam.object_classification
import os
from joblib import load
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd

from nifty.tools import blocking, takeDict
from skimage.measure import regionprops_table
from skimage.transform import resize

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util


def _compute_object_features_impl(embeddings, segmentation, resize_embedding_shape):
    # Get the embeddings and put the channel axis last.
    embeddings = embeddings.transpose(1, 2, 0)

    # Pad the segmentation to be of square shape.
    shape = segmentation.shape
    if shape[0] == shape[1]:
        segmentation_rescaled = segmentation
    elif shape[0] > shape[1]:
        segmentation_rescaled = np.pad(segmentation, ((0, 0), (0, shape[0] - shape[1])))
    elif shape[1] > shape[0]:
        segmentation_rescaled = np.pad(segmentation, ((0, shape[1] - shape[0]), (0, 0)))
    assert segmentation_rescaled.shape[0] == segmentation_rescaled.shape[1]
    shape = segmentation_rescaled.shape

    # Resize the segmentation and embeddings to be of the same size.

    # We first resize the embeddings to an intermediate shape (passed as parameter).
    # The motivation for this is to avoid losing smaller segmented objects when resizing the segmentation
    # to the original embedding shape. On the other hand, we avoid resizing the embeddings to the full
    # segmentation shape for efficiency reasons.
    resize_shape = tuple(min(rsh, sh) for rsh, sh in zip(resize_embedding_shape, shape)) + (embeddings.shape[-1],)
    embeddings = resize(embeddings, resize_shape, preserve_range=True).astype(embeddings.dtype)

    segmentation_rescaled = resize(
        segmentation_rescaled, embeddings.shape[:2], order=0, anti_aliasing=False, preserve_range=True
    ).astype(segmentation.dtype)

    # Which features do we use?
    all_features = regionprops_table(
        segmentation_rescaled, intensity_image=embeddings, properties=("label", "area", "mean_intensity"),
    )
    seg_ids = all_features["label"]
    features = pd.DataFrame(all_features)[
        ["area"] + [f"mean_intensity-{i}" for i in range(embeddings.shape[-1])]
    ].values

    return seg_ids, features


def _create_seg_and_embed_generator(segmentation, image_embeddings, is_tiled, is_3d):
    assert is_tiled or is_3d

    if is_tiled:
        tile_embeds = image_embeddings["features"]
        tile_shape, halo = tile_embeds.attrs["tile_shape"], tile_embeds.attrs["halo"]
        tiling = blocking([0, 0], tile_embeds.attrs["shape"], tile_shape)
        length = tiling.numberOfBlocks * segmentation.shape[0] if is_3d else tiling.numberOfBlocks
    else:
        tiling = None
        length = segmentation.shape[0]

    if is_3d and is_tiled:  # 3d data with tiling
        def generator():
            for z in range(segmentation.shape[0]):
                seg_z = segmentation[z]
                for block_id in range(tiling.numberOfBlocks):
                    block = tiling.getBlockWithHalo(block_id, halo)

                    # Get the embeddings and segmentation for this block and slice.
                    embeds = tile_embeds[str(block_id)][z].squeeze()
                    bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end))
                    seg = seg_z[bb]

                    yield seg, embeds

    elif is_3d:  # 3d data without tiling
        def generator():
            for z in range(length):
                seg = segmentation[z]
                embeds = image_embeddings["features"][z].squeeze()
                yield seg, embeds

    else:  # 2d data with tiling
        def generator():
            for block_id in range(length):
                block = tiling.getBlockWithHalo(block_id, halo)

                # Get the embeddings and segmentation for this block.
                embeds = tile_embeds[str(block_id)][:].squeeze()
                bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end))
                seg = segmentation[bb]

                yield seg, embeds

    return generator, length


def compute_object_features(
    image_embeddings: util.ImageEmbeddings,
    segmentation: np.ndarray,
    resize_embedding_shape: Tuple[int, int] = (256, 256),
    verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute object features based on SAM embeddings.

    Args:
        image_embeddings: The precomputed image embeddings.
        segmentation: The segmentation for which to compute the features.
        resize_embedding_shape: Shape for intermediate resizing of the embeddings.
        verbose: Whether to print a progressbar for the computation.

    Returns:
        The segmentation ids.
        The object features.
    """
    is_tiled = image_embeddings["input_size"] is None
    is_3d = segmentation.ndim == 3

    # If we have simple embeddings, i.e. 2d without tiling, then we can directly compute the features.
    if not is_tiled and not is_3d:
        embeddings = image_embeddings["features"].squeeze()
        return _compute_object_features_impl(embeddings, segmentation, resize_embedding_shape)

    # Otherwise, we compute the features by iterating over the slices and / or tiles,
    # computing the features for each slice / tile and accumulating them.

    # First, we compute the segmentation ids and initialize the required data structures.
    seg_ids = np.unique(segmentation).tolist()
    if seg_ids[0] == 0:
        seg_ids = seg_ids[1:]
    visited = {seg_id: False for seg_id in seg_ids}

    n_features = 257  # Don't hard-code?
    features = np.zeros((len(seg_ids), n_features), dtype="float32")

    # Then, we create a generator for iterating over the slices and / or tiles.
    # This generator returns the respective segmentation and embeddings.
    seg_embed_generator, n_gen = _create_seg_and_embed_generator(
        segmentation, image_embeddings, is_tiled=is_tiled, is_3d=is_3d
    )

    for seg, embeds in tqdm(
        seg_embed_generator(), total=n_gen, disable=not verbose, desc="Compute object features"
    ):
        # Compute the seg ids and features for this slice / tile.
        this_seg_ids, this_features = _compute_object_features_impl(embeds, seg, resize_embedding_shape)
        this_seg_ids = this_seg_ids.tolist()

        # Find which of the seg ids are new (= processed for the first time)
        # and which of the seg ids were already visited.
        new_idx = np.array([seg_ids.index(seg_id) for seg_id in this_seg_ids if not visited[seg_id]], dtype="int")
        visited_idx = np.array([seg_ids.index(seg_id) for seg_id in this_seg_ids if visited[seg_id]], dtype="int")

        # Get the corresponding feature indices.
        this_new_idx = np.array(
            [this_seg_ids.index(seg_id) for seg_id in this_seg_ids if not visited[seg_id]], dtype="int"
        )
        this_visited_idx = np.array(
            [this_seg_ids.index(seg_id) for seg_id in this_seg_ids if visited[seg_id]], dtype="int"
        )

        # New features can be written directly.
        features[new_idx] = this_features[this_new_idx]

        # Features that were already visited have to be merged.
        if len(visited_idx) > 0:
            # Get the sizes, which are needed for computing the mean.
            prev_size = features[visited_idx, 0:1]
            this_size = this_features[this_visited_idx, 0:1]

            # The sizes themselves are merged by addition.
            features[visited_idx, 0] += this_features[this_visited_idx, 0]

            # Mean values are merged via a size-weighted sum.
            features[visited_idx, 1:] = (
                prev_size * features[visited_idx, 1:] + this_size * this_features[this_visited_idx, 1:]
            ) / (prev_size + this_size)

        # Set all seg ids from this block to visited.
        visited.update({seg_id: True for seg_id in this_seg_ids})

    return np.array(seg_ids), features


def project_prediction_to_segmentation(
    segmentation: np.ndarray,
    object_prediction: np.ndarray,
    seg_ids: np.ndarray,
) -> np.ndarray:
    """Project object level prediction to the corresponding segmentation to obtain a pixel level prediction.

    Args:
        segmentation: The segmentation from which the object prediction is derived.
        object_prediction: The object prediction.
        seg_ids: The segmentation ids matching the object prediction.

    Returns:
        The pixel level object prediction, corresponding to a semantic segmentation.
    """
    assert len(object_prediction) == len(seg_ids)
    prediction = {seg_id: class_pred for seg_id, class_pred in zip(seg_ids, object_prediction)}
    # Find missing segmentation ids. This will include the background id, but may include other ids of small objects.
    # Such objects may get removed in the resizing operations.
    missing_ids = np.setdiff1d(np.unique(segmentation), seg_ids)
    prediction.update({missing_id: 0 for missing_id in missing_ids})
    return takeDict(prediction, segmentation)


# TODO handle images / segmentations as file paths
# TODO think about the function signature, especially how exactly we pass the model and optional embedding path.
# TODO halo and tile shape
# TODO add heuristic for ndim
def run_prediction_with_object_classifier(
    images: Sequence[Union[str, os.PathLike, np.ndarray]],
    segmentations: Sequence[Union[str, os.PathLike, np.ndarray]],
    predictor,
    rf_path: Union[str, os.PathLike],
    image_key: Optional[str] = None,
    segmentation_key: Optional[str] = None,
    project_prediction: bool = True,
    ndim: Optional[int] = None,
) -> List[np.ndarray]:
    """Run prediction with a pretrained object classifier on a series of images.

    Args:
        images: The images, either given as a list of numpy arrays or as filepaths.
        segmentations: The segmentations, either given as a list of numpy arrays or as filepaths.
        predictor: The Segment Anything predictor used to compute the image embeddings.
        rf_path: The path to the trained random forest classifier, which is loaded with joblib.
        image_key: The key for loading the image data from a file, e.g. for hdf5 or zarr containers.
        segmentation_key: The key for loading the segmentation data from a file.
        project_prediction: Whether to project the object predictions to pixel-level semantic segmentations.
        ndim: The dimensionality of the image data.

    Returns:
        The predictions.
    """
    assert len(images) == len(segmentations)
    rf = load(rf_path)
    predictions = []
    for image, segmentation in tqdm(
        zip(images, segmentations), total=len(images), desc="Run prediction with object classifier"
    ):
        embeddings = util.precompute_image_embeddings(predictor, image, verbose=False, ndim=ndim)
        seg_ids, features = compute_object_features(embeddings, segmentation, verbose=False)
        prediction = rf.predict(features)
        if project_prediction:
            prediction = project_prediction_to_segmentation(segmentation, prediction, seg_ids)
        predictions.append(prediction)
    return predictions
def compute_object_features(
    image_embeddings: Dict[str, Any],
    segmentation: numpy.ndarray,
    resize_embedding_shape: Tuple[int, int] = (256, 256),
    verbose: bool = True
) -> Tuple[numpy.ndarray, numpy.ndarray]:
Compute object features based on SAM embeddings.
Arguments:
- image_embeddings: The precomputed image embeddings.
- segmentation: The segmentation for which to compute the features.
- resize_embedding_shape: Shape for intermediate resizing of the embeddings.
- verbose: Whether to print a progressbar for the computation.
Returns:
The segmentation ids. The object features.
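A hedged usage sketch for the simple (2d, untiled) case; it assumes micro_sam's util.get_sam_model and util.precompute_image_embeddings for obtaining the embeddings, and the image and segmentation here are random placeholders:

import numpy as np

from micro_sam import util
from micro_sam.object_classification import compute_object_features

# Placeholder inputs: a random 2d image and a single-object instance segmentation.
image = np.random.rand(512, 512).astype("float32")
segmentation = np.zeros((512, 512), dtype="uint32")
segmentation[100:200, 100:200] = 1

predictor = util.get_sam_model(model_type="vit_b")
embeddings = util.precompute_image_embeddings(predictor, image, ndim=2)

seg_ids, features = compute_object_features(embeddings, segmentation)
# One feature row per object: the (rescaled) object area plus the mean of each embedding channel.
print(seg_ids.shape, features.shape)  # e.g. (1,) (1, 257)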
def project_prediction_to_segmentation(
    segmentation: numpy.ndarray,
    object_prediction: numpy.ndarray,
    seg_ids: numpy.ndarray
) -> numpy.ndarray:
Project object level prediction to the corresponding segmentation to obtain a pixel level prediction.
Arguments:
- segmentation: The segmentation from which the object prediction is derived.
- object_prediction: The object prediction.
- seg_ids: The segmentation ids matching the object prediction.
Returns:
The pixel level object prediction, corresponding to a semantic segmentation.
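A toy example of the projection, with made-up ids and class predictions: ids present in seg_ids receive their predicted class, while missing ids, including the background, are mapped to class 0.

import numpy as np

from micro_sam.object_classification import project_prediction_to_segmentation

segmentation = np.array([
    [0, 1, 1],
    [0, 2, 2],
    [3, 3, 0],
], dtype="uint32")

seg_ids = np.array([1, 2])            # the objects that have a prediction
object_prediction = np.array([2, 1])  # their predicted classes

semantic = project_prediction_to_segmentation(segmentation, object_prediction, seg_ids)
# Background (0) and the unclassified object 3 are mapped to class 0.
print(semantic)
# [[0 2 2]
#  [0 1 1]
#  [0 0 0]]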
def run_prediction_with_object_classifier(
    images: Sequence[Union[str, os.PathLike, numpy.ndarray]],
    segmentations: Sequence[Union[str, os.PathLike, numpy.ndarray]],
    predictor,
    rf_path: Union[str, os.PathLike],
    image_key: Optional[str] = None,
    segmentation_key: Optional[str] = None,
    project_prediction: bool = True,
    ndim: Optional[int] = None
) -> List[numpy.ndarray]:
Run prediction with a pretrained object classifier on a series of images.
Arguments:
- images: The images, either given as a list of numpy arrays or as filepaths.
- segmentations: The segmentations, either given as a list of numpy arrays or as filepaths.
- predictor: The Segment Anything predictor used to compute the image embeddings.
- rf_path: The path to the trained random forest classifier, which is loaded with joblib.
- image_key: The key for loading the image data from a file, e.g. for hdf5 or zarr containers.
- segmentation_key: The key for loading the segmentation data from a file.
- project_prediction: Whether to project the object predictions to pixel-level semantic segmentations.
- ndim: The dimensionality of the image data.
Returns:
The predictions.
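A sketch of end-to-end usage, under the assumption that a random forest trained on compute_object_features output was saved with joblib; "object_classifier.joblib" is a hypothetical path and the inputs are random placeholders:

import numpy as np

from micro_sam import util
from micro_sam.object_classification import run_prediction_with_object_classifier

# Placeholder data: in practice these would be real images with matching instance segmentations.
images = [np.random.rand(256, 256).astype("float32") for _ in range(2)]
segmentations = [np.random.randint(0, 5, size=(256, 256)).astype("uint32") for _ in range(2)]

predictor = util.get_sam_model(model_type="vit_b")
predictions = run_prediction_with_object_classifier(
    images, segmentations, predictor,
    rf_path="object_classifier.joblib",  # hypothetical path to a trained random forest
    ndim=2,
)
# With project_prediction=True (the default), each result is a pixel-level semantic segmentation.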