micro_sam.object_classification
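
Classify segmented objects based on Segment Anything (SAM) image embeddings: compute
per-object features from the embeddings, apply a pretrained random forest classifier
to them, and project the object-level predictions back to a pixel-level segmentation.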

import os
from joblib import load
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd

from nifty.tools import blocking, takeDict
from skimage.measure import regionprops_table
from skimage.transform import resize

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util


def _compute_object_features_impl(embeddings, segmentation, resize_embedding_shape):
    # Get the embeddings and put the channel axis last.
    embeddings = embeddings.transpose(1, 2, 0)

    # Pad the segmentation to be of square shape.
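    # (SAM pads its inputs to a square before embedding them, so the embeddings
    # correspond to a square image and the segmentation must be padded to match.)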
    shape = segmentation.shape
    if shape[0] == shape[1]:
        segmentation_rescaled = segmentation
    elif shape[0] > shape[1]:
        segmentation_rescaled = np.pad(segmentation, ((0, 0), (0, shape[0] - shape[1])))
    elif shape[1] > shape[0]:
        segmentation_rescaled = np.pad(segmentation, ((0, shape[1] - shape[0]), (0, 0)))
    assert segmentation_rescaled.shape[0] == segmentation_rescaled.shape[1]
    shape = segmentation_rescaled.shape

    # Resize the segmentation and embeddings to the same size.

    # We first resize the embeddings to an intermediate shape (passed as a parameter).
    # This avoids losing small segmented objects when resizing the segmentation down
    # to the original embedding shape, while also avoiding the cost of resizing the
    # embeddings up to the full segmentation shape.
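    # For example, with the default intermediate shape of (256, 256), a 512x512
    # segmentation is downscaled to 256x256, while the embedding grid (64x64 for
    # the standard SAM image encoder) is upscaled to 256x256, so that both are
    # compared at the same resolution.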
    resize_shape = tuple(min(rsh, sh) for rsh, sh in zip(resize_embedding_shape, shape)) + (embeddings.shape[-1],)
    embeddings = resize(embeddings, resize_shape, preserve_range=True).astype(embeddings.dtype)

    segmentation_rescaled = resize(
        segmentation_rescaled, embeddings.shape[:2], order=0, anti_aliasing=False, preserve_range=True
    ).astype(segmentation.dtype)

    # Compute the per-object features: the object area and the mean embedding intensity per channel.
    all_features = regionprops_table(
        segmentation_rescaled, intensity_image=embeddings, properties=("label", "area", "mean_intensity"),
    )
    seg_ids = all_features["label"]
    features = pd.DataFrame(all_features)[
        ["area"] + [f"mean_intensity-{i}" for i in range(embeddings.shape[-1])]
    ].values

    return seg_ids, features


def _create_seg_and_embed_generator(segmentation, image_embeddings, is_tiled, is_3d):
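    # Returns a generator factory that yields (segmentation, embeddings) pairs
    # per slice and / or tile, together with the total number of iterations.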
    assert is_tiled or is_3d

    if is_tiled:
        tile_embeds = image_embeddings["features"]
        tile_shape, halo = tile_embeds.attrs["tile_shape"], tile_embeds.attrs["halo"]
        tiling = blocking([0, 0], tile_embeds.attrs["shape"], tile_shape)
        length = tiling.numberOfBlocks * segmentation.shape[0] if is_3d else tiling.numberOfBlocks
    else:
        tiling = None
        length = segmentation.shape[0]

    if is_3d and is_tiled:  # 3d data with tiling
        def generator():
            for z in range(segmentation.shape[0]):
                seg_z = segmentation[z]
                for block_id in range(tiling.numberOfBlocks):
                    block = tiling.getBlockWithHalo(block_id, halo)

                    # Get the embeddings and segmentation for this block and slice.
                    embeds = tile_embeds[str(block_id)][z].squeeze()

                    bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end))
                    seg = seg_z[bb]

                    yield seg, embeds

    elif is_3d:  # 3d data no tiling
        def generator():
            for z in range(length):
                seg = segmentation[z]
                embeds = image_embeddings["features"][z].squeeze()
                yield seg, embeds

    else:  # 2d data with tiling
        def generator():
            for block_id in range(length):
                block = tiling.getBlockWithHalo(block_id, halo)

                # Get the embeddings and segmentation for this block.
                embeds = tile_embeds[str(block_id)][:].squeeze()
                bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end))
                seg = segmentation[bb]

                yield seg, embeds

    return generator, length


def compute_object_features(
    image_embeddings: util.ImageEmbeddings,
    segmentation: np.ndarray,
    resize_embedding_shape: Tuple[int, int] = (256, 256),
    verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute object features based on SAM embeddings.

    Args:
        image_embeddings: The precomputed image embeddings.
        segmentation: The segmentation for which to compute the features.
        resize_embedding_shape: Shape for intermediate resizing of the embeddings.
        verbose: Whether to print a progress bar for the computation.

    Returns:
        The segmentation ids.
        The object features.
    """
    is_tiled = image_embeddings["input_size"] is None
    is_3d = segmentation.ndim == 3

    # If we have simple embeddings, i.e. 2d without tiling, then we can directly compute the features.
    if not is_tiled and not is_3d:
        embeddings = image_embeddings["features"].squeeze()
        return _compute_object_features_impl(embeddings, segmentation, resize_embedding_shape)

    # Otherwise, we compute the features by iterating over the slices and / or tiles,
    # computing the features for each slice / tile and accumulating them.

    # First, we compute the segmentation ids and initialize the required data structures.
    seg_ids = np.unique(segmentation).tolist()
    if seg_ids[0] == 0:
        seg_ids = seg_ids[1:]
    visited = {seg_id: False for seg_id in seg_ids}

    n_features = 257  # The object area plus the 256 embedding channels. TODO: don't hard-code this?
    features = np.zeros((len(seg_ids), n_features), dtype="float32")

    # Then, we create a generator for iterating over the slices and / or tiles.
    # This generator returns the respective segmentation and embeddings.
    seg_embed_generator, n_gen = _create_seg_and_embed_generator(
        segmentation, image_embeddings, is_tiled=is_tiled, is_3d=is_3d
    )

    for seg, embeds in tqdm(
        seg_embed_generator(), total=n_gen, disable=not verbose, desc="Compute object features"
    ):
        # Compute the seg ids and features for this slice / tile.
        this_seg_ids, this_features = _compute_object_features_impl(embeds, seg, resize_embedding_shape)
        this_seg_ids = this_seg_ids.tolist()

        # Find the seg ids that are new (= processed for the first time)
        # and the seg ids that were already visited.
        new_idx = np.array([seg_ids.index(seg_id) for seg_id in this_seg_ids if not visited[seg_id]], dtype="int")
        visited_idx = np.array([seg_ids.index(seg_id) for seg_id in this_seg_ids if visited[seg_id]], dtype="int")

        # Get the corresponding feature indices.
        this_new_idx = np.array(
            [this_seg_ids.index(seg_id) for seg_id in this_seg_ids if not visited[seg_id]], dtype="int"
        )
        this_visited_idx = np.array(
            [this_seg_ids.index(seg_id) for seg_id in this_seg_ids if visited[seg_id]], dtype="int"
        )

        # New features can be written directly.
        features[new_idx] = this_features[this_new_idx]

        # Features that were already visited are merged.
        if len(visited_idx) > 0:
            # Get the sizes, which are needed for computing the mean.
            prev_size = features[visited_idx, 0:1]
            this_size = this_features[this_visited_idx, 0:1]

            # The sizes themselves are merged by addition.
            features[visited_idx, 0] += this_features[this_visited_idx, 0]

            # Mean values are merged via a weighted average, weighted by the object sizes.
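            # For example, an object with area 10 and mean 0.2 in previous tiles and
            # area 30 and mean 0.6 in this tile gets the merged mean
            # (10 * 0.2 + 30 * 0.6) / (10 + 30) = 0.5.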
            features[visited_idx, 1:] = (
                prev_size * features[visited_idx, 1:] + this_size * this_features[this_visited_idx, 1:]
            ) / (prev_size + this_size)

        # Set all seg ids from this block to visited.
        visited.update({seg_id: True for seg_id in this_seg_ids})

    return np.array(seg_ids), features


def project_prediction_to_segmentation(
    segmentation: np.ndarray,
    object_prediction: np.ndarray,
    seg_ids: np.ndarray
) -> np.ndarray:
    """Project an object-level prediction to the corresponding segmentation to obtain a pixel-level prediction.

    Args:
        segmentation: The segmentation from which the object prediction is derived.
        object_prediction: The object prediction.
        seg_ids: The segmentation ids matching the object prediction.

    Returns:
        The pixel-level object prediction, corresponding to a semantic segmentation.
    """
    assert len(object_prediction) == len(seg_ids)
    prediction = {seg_id: class_pred for seg_id, class_pred in zip(seg_ids, object_prediction)}
    # Find the missing segmentation ids. These will include the background id, but may also include
    # the ids of small objects, since such objects can get lost in the resizing operations.
    missing_ids = np.setdiff1d(np.unique(segmentation), seg_ids)
    prediction.update({missing_id: 0 for missing_id in missing_ids})
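    # takeDict maps each pixel to the class predicted for its object id. For example,
    # with seg_ids [1, 2] and object_prediction [2, 1], the segmentation
    # [[0, 1], [2, 2]] is mapped to [[0, 2], [1, 1]].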
    return takeDict(prediction, segmentation)


# TODO handle images / segmentations as file paths
# TODO think about the function signature, especially how exactly we pass the model and an optional embedding path.
# TODO halo and tile shape
# TODO add heuristic for ndim
def run_prediction_with_object_classifier(
    images: Sequence[Union[str, os.PathLike, np.ndarray]],
    segmentations: Sequence[Union[str, os.PathLike, np.ndarray]],
    predictor,
    rf_path: Union[str, os.PathLike],
    image_key: Optional[str] = None,
    segmentation_key: Optional[str] = None,
    project_prediction: bool = True,
    ndim: Optional[int] = None,
) -> List[np.ndarray]:
    """Run prediction with a pretrained object classifier on a series of images.

    Args:
        images: The images, either given as a list of numpy arrays or as filepaths.
        segmentations: The segmentations, either given as a list of numpy arrays or as filepaths.
        predictor: The Segment Anything predictor.
        rf_path: The path to the pretrained random forest classifier.
        image_key: The key for the image data, for images stored in container formats such as hdf5 or zarr.
        segmentation_key: The key for the segmentation data, for segmentations stored in container formats.
        project_prediction: Whether to project the object predictions to pixel-level predictions.
        ndim: The dimensionality of the image data.

    Returns:
        The predictions.
    """
    assert len(images) == len(segmentations)
    rf = load(rf_path)
    predictions = []
    for image, segmentation in tqdm(
        zip(images, segmentations), total=len(images), desc="Run prediction with object classifier"
    ):
        embeddings = util.precompute_image_embeddings(predictor, image, verbose=False, ndim=ndim)
        seg_ids, features = compute_object_features(embeddings, segmentation, verbose=False)
        prediction = rf.predict(features)
        if project_prediction:
            prediction = project_prediction_to_segmentation(segmentation, prediction, seg_ids)
        predictions.append(prediction)
    return predictions
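
Example usage: a minimal sketch for training an object classifier on the features
computed by this module and applying it with the functions above. The inputs
`image`, `segmentation`, `object_labels` (one class label per object, in the order
of the returned seg_ids), `new_image`, and `new_segmentation` are assumed to be
provided by the user; the model type and the output path are placeholders.

    import numpy as np
    from joblib import dump
    from sklearn.ensemble import RandomForestClassifier

    from micro_sam import util
    from micro_sam.object_classification import (
        compute_object_features, run_prediction_with_object_classifier,
    )

    # Compute the embeddings and the per-object features
    # (the object area plus the mean embedding value per channel).
    predictor = util.get_sam_model(model_type="vit_b")
    embeddings = util.precompute_image_embeddings(predictor, image, ndim=2)
    seg_ids, features = compute_object_features(embeddings, segmentation)

    # Train a random forest on the object features and save it with joblib,
    # so that it can be loaded via rf_path below.
    rf = RandomForestClassifier(n_estimators=100)
    rf.fit(features, object_labels)
    dump(rf, "object_classifier.joblib")

    # Apply the classifier to new data; the predictions are projected
    # to pixel level, yielding one semantic segmentation per image.
    predictions = run_prediction_with_object_classifier(
        [new_image], [new_segmentation], predictor, rf_path="object_classifier.joblib",
    )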