micro_sam.evaluation.multi_dimensional_segmentation

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from math import floor
from itertools import product
from typing import Union, Tuple, Optional, List, Dict

import imageio.v3 as imageio

import torch

from elf.evaluation import mean_segmentation_accuracy

from .. import util
from ..inference import batched_inference
from ..prompt_generators import PointAndBoxPromptGenerator
from ..multi_dimensional_segmentation import segment_mask_in_volume
from ..evaluation.instance_segmentation import _get_range_of_search_values, evaluate_instance_segmentation_grid_search


def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[Union[str, dict]] = None,
    box_extension_values: Optional[Union[float, int]] = None
) -> Dict[str, List]:
    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

    Args:
        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
            By default, values in the range from 0.5 to 0.9 with a step size of 0.1 are used.
        projection_method_values: The values for the `projection` method used in the grid-search.
            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
        box_extension_values: The values for `box_extension` used in the grid-search.
            By default, values in the range from 0 to 0.25 with a step size of 0.025 are used.

    Returns:
        The values for the grid search.
    """
    if iou_threshold_values is None:
        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)

    if projection_method_values is None:
        projection_method_values = [
            "mask", "points", "box", "points_and_mask", "single_point"
        ]

    if box_extension_values is None:
        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)

    return {
        "iou_threshold": iou_threshold_values,
        "projection": projection_method_values,
        "box_extension": box_extension_values
    }


@torch.no_grad()
def segment_slices_from_ground_truth(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    save_path: Optional[Union[str, os.PathLike]] = None,
    iou_threshold: float = 0.8,
    projection: Union[str, dict] = "mask",
    box_extension: Union[float, int] = 0.025,
    device: Optional[Union[str, torch.device]] = None,
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    return_segmentation: bool = False,
    min_size: int = 0,
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]]:
    """Segment all objects in a volume by prompt-based segmentation in one slice per object.

    This function first segments each object in the middle slice of the object using interactive
    (prompt-based) segmentation. Then it projects the segmentation of that object to the
    remaining slices of the volume.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        save_path: Path to store the segmentations.
        iou_threshold: The IoU threshold used to decide whether to link an object to the
            segmentation in the consecutive slice.
        projection: The projection (prompting) method to generate prompts for consecutive slices.
        box_extension: Extension factor for increasing the box size after projection.
        device: The selected device for computation.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        return_segmentation: Whether to return the segmented volume.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.

    Returns:
        A DataFrame with the evaluation results (mSA, SA50, SA75),
        and the segmented volume if `return_segmentation` is set to True.
    """
    assert volume.ndim == 3

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)

    # Compute the image embeddings
    embeddings = util.precompute_image_embeddings(
        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
    )

    # Compute instance ids (without the background)
    label_ids = np.unique(ground_truth)[1:]
    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."

    # Create an empty volume to store incoming segmentations
    final_segmentation = np.zeros_like(ground_truth)

    skipped_label_ids = []
    for label_id in label_ids:
        # Binary label volume per instance (also referred to as object)
        this_seg = (ground_truth == label_id).astype("int")

        # Let's search the slices where we have the current object
        slice_range = np.where(this_seg)[0]

        # Choose the middle slice of the current object for prompt-based segmentation
        slice_range = (slice_range.min(), slice_range.max())
        slice_choice = floor(np.mean(slice_range))
        this_slice_seg = this_seg[slice_choice]
        if min_size > 0 and this_slice_seg.sum() < min_size:
            skipped_label_ids.append(label_id)
            continue

        if verbose:
            print(f"The object with id {label_id} lies in slice range: {slice_range}")

        # Prompts for segmentation for the current slice
        if interactive_seg_mode == "points":
            _get_points, _get_box = True, False
        elif interactive_seg_mode == "box":
            _get_points, _get_box = False, True
        else:
            raise ValueError(
                f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
                "Please choose from 'box' / 'points'."
            )

        prompt_generator = PointAndBoxPromptGenerator(
            n_positive_points=1 if _get_points else 0,
            n_negative_points=1 if _get_points else 0,
            dilation_strength=10,
            get_point_prompts=_get_points,
            get_box_prompts=_get_box
        )
        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
        point_prompts, point_labels, box_prompts, _ = prompt_generator(
            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
            bbox_coordinates=[box_coords[1]],
        )

        # Prompt-based segmentation on middle slice of the current object
        output_slice = batched_inference(
            predictor=predictor,
            image=volume[slice_choice],
            batch_size=1,
            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
            verbose_embeddings=verbose,
        )
        output_seg = np.zeros_like(ground_truth)
        output_seg[slice_choice][output_slice == 1] = 1

        # Segment the object in the entire volume with the specified segmented slice
        this_seg, _ = segment_mask_in_volume(
            segmentation=output_seg,
            predictor=predictor,
            image_embeddings=embeddings,
            segmented_slices=np.array(slice_choice),
            stop_lower=False, stop_upper=False,
            iou_threshold=iou_threshold,
            projection=projection,
            box_extension=box_extension,
            verbose=verbose,
        )

        # Store the entire segmented object
        final_segmentation[this_seg == 1] = label_id

    # Save the volumetric segmentation
    if save_path is not None:
        imageio.imwrite(save_path, final_segmentation, compression="zlib")

    # Evaluate the volumetric segmentation
    if skipped_label_ids:
        curr_gt = ground_truth.copy()
        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
    else:
        curr_gt = ground_truth

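    # Note: `sa` holds the segmentation accuracies at IoU thresholds 0.5, 0.55, ..., 0.95,
    # so sa[0] corresponds to SA50 and sa[5] to SA75 below.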
    msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True)
    results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
    results = pd.DataFrame.from_dict([results])

    if return_segmentation:
        return results, final_segmentation
    else:
        return results


def _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values):
    if os.path.exists(best_params_path):
        print("The best parameters are already saved at:", best_params_path)
        return

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))

    # Let's save the best parameters.
    best_kwargs["mSA"] = best_msa
    best_param_df = pd.DataFrame.from_dict([best_kwargs])
    best_param_df.to_csv(best_params_path)

    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)


def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0
) -> str:
    """Run grid search for prompt-based multi-dimensional instance segmentation.

    The parameters and their respective value ranges for the grid search are specified via the
    `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`,
    `projection` and `box_extension`, you can pass the following:
    ```
    grid_search_values = {
        "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
        "projection": ["mask", "points", "box"],
        "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
    }
    ```
    All combinations of the parameters will be checked.
    If `None` is passed, the function `default_grid_search_values_multi_dimensional_segmentation` is used
    to get the default grid search parameters for the instance segmentation method.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        result_dir: Path to the directory where the grid search results are saved.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.

    Returns:
        Path to the CSV file with the best grid-search parameters.
    """
    if grid_search_values is None:
        grid_search_values = default_grid_search_values_multi_dimensional_segmentation()

    assert len(grid_search_values.keys()) == 3, "There must be three grid-search parameters. See above for details."

    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "all_grid_search_results.csv")
    best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")
    if os.path.exists(result_path):
        _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
        return best_params_path

    # Compute all combinations of grid search values.
    gs_combinations = product(*grid_search_values.values())

    # Map each combination back to a valid kwarg input.
    gs_combinations = [
        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
    ]

    net_list = []
    for gs_kwargs in tqdm(gs_combinations):
        results = segment_slices_from_ground_truth(
            volume=volume,
            ground_truth=ground_truth,
            model_type=model_type,
            checkpoint_path=checkpoint_path,
            embedding_path=embedding_path,
            interactive_seg_mode=interactive_seg_mode,
            verbose=verbose,
            return_segmentation=False,
            min_size=min_size,
            **gs_kwargs
        )

        # `results` is a single-row DataFrame; merge its values with the grid-search kwargs.
        result_dict = {**results.iloc[0].to_dict(), **gs_kwargs}
        tmp_df = pd.DataFrame([result_dict])
        net_list.append(tmp_df)

    res_df = pd.concat(net_list, ignore_index=True)
    res_df.to_csv(result_path)

    _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
    print("The best grid-search parameters have been computed and stored at:", best_params_path)
    return best_params_path
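
For reference, here is a minimal usage sketch of the full grid-search workflow. It assumes a 3D volume and matching instance labels stored as TIFF files; the file paths and the `vit_b` model choice are placeholders:

```
import imageio.v3 as imageio

from micro_sam.evaluation.multi_dimensional_segmentation import run_multi_dimensional_segmentation_grid_search

# Placeholder inputs: any 3D numpy arrays of raw data and instance labels work.
volume = imageio.imread("volume.tif")
ground_truth = imageio.imread("labels.tif")

best_params_path = run_multi_dimensional_segmentation_grid_search(
    volume=volume,
    ground_truth=ground_truth,
    model_type="vit_b",                  # any supported segment anything model
    checkpoint_path=None,                # None: use the default weights for this model type
    embedding_path="./embeddings.zarr",  # cache for the precomputed image embeddings
    result_dir="./grid_search_results",
)
print("The best parameters are stored at:", best_params_path)
```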
def default_grid_search_values_multi_dimensional_segmentation( iou_threshold_values: Optional[List[float]] = None, projection_method_values: Optional[Union[str, dict]] = None, box_extension_values: Optional[Union[float, int]] = None) -> Dict[str, List]:

Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

Arguments:
  • iou_threshold_values: The values for iou_threshold used in the grid-search. By default, values in the range from 0.5 to 0.9 with a step size of 0.1 are used.
  • projection_method_values: The values for the projection method used in the grid-search. By default the values mask, points, box, points_and_mask and single_point are used.
  • box_extension_values: The values for box_extension used in the grid-search. By default, values in the range from 0 to 0.25 with a step size of 0.025 are used.

Returns:

The values for the grid search.
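
As an illustration, a short sketch of how the search space can be customized; the expected lists in the comments follow the defaults described above (the exact values are produced by `_get_range_of_search_values`):

```
from micro_sam.evaluation.multi_dimensional_segmentation import (
    default_grid_search_values_multi_dimensional_segmentation,
)

# Restrict the search to two projection methods; the other parameters keep their defaults.
grid_values = default_grid_search_values_multi_dimensional_segmentation(
    projection_method_values=["mask", "points"],
)

print(grid_values["projection"])     # ['mask', 'points']
print(grid_values["iou_threshold"])  # expected: [0.5, 0.6, 0.7, 0.8, 0.9]
print(grid_values["box_extension"])  # expected: [0.0, 0.025, ..., 0.25]
```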

@torch.no_grad()
def segment_slices_from_ground_truth( volume: numpy.ndarray, ground_truth: numpy.ndarray, model_type: str, checkpoint_path: Optional[Union[str, os.PathLike]] = None, embedding_path: Optional[Union[str, os.PathLike]] = None, save_path: Optional[Union[str, os.PathLike]] = None, iou_threshold: float = 0.8, projection: Union[str, dict] = 'mask', box_extension: Union[float, int] = 0.025, device: Optional[Union[str, torch.device]] = None, interactive_seg_mode: str = 'box', verbose: bool = False, return_segmentation: bool = False, min_size: int = 0) -> Union[pandas.DataFrame, Tuple[pandas.DataFrame, numpy.ndarray]]:

Segment all objects in a volume by prompt-based segmentation in one slice per object.

This function first segments each object in the middle slice of the object using interactive (prompt-based) segmentation. Then it projects the segmentation of that object to the remaining slices of the volume.

Arguments:
  • volume: The input volume.
  • ground_truth: The label volume with instance segmentations.
  • model_type: Choice of segment anything model.
  • checkpoint_path: Path to the model checkpoint.
  • embedding_path: Path to cache the computed embeddings.
  • save_path: Path to store the segmentations.
  • iou_threshold: The IoU threshold used to decide whether to link an object to the segmentation in the consecutive slice.
  • projection: The projection (prompting) method to generate prompts for consecutive slices.
  • box_extension: Extension factor for increasing the box size after projection.
  • device: The selected device for computation.
  • interactive_seg_mode: Method for guiding prompt-based instance segmentation.
  • verbose: Whether to get the trace for projected segmentations.
  • return_segmentation: Whether to return the segmented volume.
  • min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.

Returns:

A DataFrame with the evaluation results (mSA, SA50, SA75), and the segmented volume if return_segmentation is set to True.
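
A minimal usage sketch, assuming a 3D TIFF volume with matching instance labels; the file paths and the vit_b model choice are placeholders:

```
import imageio.v3 as imageio

from micro_sam.evaluation.multi_dimensional_segmentation import segment_slices_from_ground_truth

volume = imageio.imread("volume.tif")        # 3D raw data, shape (Z, Y, X)
ground_truth = imageio.imread("labels.tif")  # 3D instance labels, same shape

results, segmentation = segment_slices_from_ground_truth(
    volume=volume,
    ground_truth=ground_truth,
    model_type="vit_b",
    embedding_path="./embeddings.zarr",  # cache for the image embeddings
    save_path="./segmentation.tif",      # store the segmented volume
    iou_threshold=0.8,
    projection="mask",
    box_extension=0.025,
    return_segmentation=True,
)
print(results)  # one-row DataFrame with mSA, SA50 and SA75
```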