micro_sam.evaluation.multi_dimensional_segmentation

  1import os
  2from tqdm import tqdm
  3from itertools import product
  4from typing import Union, Tuple, Optional, List, Dict, Literal
  5
  6import numpy as np
  7import pandas as pd
  8from math import floor
  9import imageio.v3 as imageio
 10
 11import torch
 12
 13from elf.evaluation import mean_segmentation_accuracy, dice_score
 14
 15from .. import util
 16from ..inference import batched_inference
 17from ..prompt_generators import PointAndBoxPromptGenerator
 18from ..multi_dimensional_segmentation import segment_mask_in_volume
 19from ..evaluation.instance_segmentation import _get_range_of_search_values, evaluate_instance_segmentation_grid_search
 20
 21
 22def default_grid_search_values_multi_dimensional_segmentation(
 23    iou_threshold_values: Optional[List[float]] = None,
 24    projection_method_values: Optional[Union[str, dict]] = None,
 25    box_extension_values: Optional[Union[float, int]] = None
 26) -> Dict[str, List]:
 27    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.
 28
 29    Args:
 30        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
 31            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
 32        projection_method_values: The values for `projection` method used in the grid-search.
 33            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
 34        box_extension_values: The values for `box_extension` used in the grid-search.
 35            By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
 36
 37    Returns:
 38        The values for grid search.
 39    """
 40    if iou_threshold_values is None:
 41        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)
 42
 43    if projection_method_values is None:
 44        projection_method_values = [
 45            "mask", "points", "box", "points_and_mask", "single_point"
 46        ]
 47
 48    if box_extension_values is None:
 49        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)
 50
 51    return {
 52        "iou_threshold": iou_threshold_values,
 53        "projection": projection_method_values,
 54        "box_extension": box_extension_values
 55    }
 56
 57
 58@torch.no_grad()
 59def segment_slices_from_ground_truth(
 60    volume: np.ndarray,
 61    ground_truth: np.ndarray,
 62    model_type: str,
 63    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 64    embedding_path: Optional[Union[str, os.PathLike]] = None,
 65    save_path: Optional[Union[str, os.PathLike]] = None,
 66    iou_threshold: float = 0.8,
 67    projection: Union[str, dict] = "mask",
 68    box_extension: Union[float, int] = 0.025,
 69    device: Union[str, torch.device] = None,
 70    interactive_seg_mode: str = "box",
 71    verbose: bool = False,
 72    return_segmentation: bool = False,
 73    min_size: int = 0,
 74    evaluation_metric: Literal["sa", "dice"] = "sa",
 75) -> Union[Dict, Tuple[Dict, np.ndarray]]:
 76    """Segment all objects in a volume by prompt-based segmentation in one slice per object.
 77
 78    This function first segments each object in the respective specified slice using interactive
 79    (prompt-based) segmentation functionality. Then it segments the particular object in the
 80    remaining slices in the volume.
 81
 82    Args:
 83        volume: The input volume.
 84        ground_truth: The label volume with instance segmentations.
 85        model_type: Choice of segment anything model.
 86        checkpoint_path: Path to the model checkpoint.
 87        embedding_path: Path to cache the computed embeddings.
 88        save_path: Path to store the segmentations.
 89        iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation.
 90        projection: The projection (prompting) method to generate prompts for consecutive slices.
 91        box_extension: Extension factor for increasing the box size after projection.
 92        device: The selected device for computation.
 93        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
 94        verbose: Whether to get the trace for projected segmentations.
 95        return_segmentation: Whether to return the segmented volume.
 96        min_size: The minimal size for evaluating an object in the ground-truth.
 97            The size is measured within the central slice.
 98        evaluation_metric: The choice of supported metric to evaluate predictions.
 99
100    Returns:
101        A dictionary of results with all desired metrics.
102        Optional segmentation result (controlled by `return_segmentation` argument).
103    """
104    assert volume.ndim == 3
105
106    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)
107
108    # Compute the image embeddings
109    embeddings = util.precompute_image_embeddings(
110        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
111    )
112
113    # Compute instance ids (without the background)
114    label_ids = np.unique(ground_truth)[1:]
115    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."
116
117    # Create an empty volume to store incoming segmentations
118    final_segmentation = np.zeros_like(ground_truth)
119
120    _segmentation_completed = False
121    if save_path is not None and os.path.exists(save_path):
122        _segmentation_completed = True  # We avoid rerunning the segmentation if it is completed.
123
124    skipped_label_ids = []
125    for label_id in tqdm(label_ids, desc="Segmenting per object in the volume", disable=not verbose):
126        # Binary label volume per instance (also referred to as object)
127        this_seg = (ground_truth == label_id).astype("int")
128
129        # Let's search the slices where we have the current object
130        slice_range = np.where(this_seg)[0]
131
132        # Choose the middle slice of the current object for prompt-based segmentation
133        slice_range = (slice_range.min(), slice_range.max())
134        slice_choice = floor(np.mean(slice_range))
135        this_slice_seg = this_seg[slice_choice]
136        if min_size > 0 and this_slice_seg.sum() < min_size:
137            skipped_label_ids.append(label_id)
138            continue
139
140        if _segmentation_completed:
141            continue
142
143        if verbose:
144            print(f"The object with id {label_id} lies in slice range: {slice_range}")
145
146        # Prompts for segmentation for the current slice
147        if interactive_seg_mode == "points":
148            _get_points, _get_box = True, False
149        elif interactive_seg_mode == "box":
150            _get_points, _get_box = False, True
151        else:
152            raise ValueError(
153                f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
154                "Please choose from 'box' / 'points'."
155            )
156
157        prompt_generator = PointAndBoxPromptGenerator(
158            n_positive_points=1 if _get_points else 0,
159            n_negative_points=1 if _get_points else 0,
160            dilation_strength=10,
161            get_point_prompts=_get_points,
162            get_box_prompts=_get_box
163        )
164        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
165        point_prompts, point_labels, box_prompts, _ = prompt_generator(
166            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
167            bbox_coordinates=[box_coords[1]],
168        )
169
170        # Prompt-based segmentation on middle slice of the current object
171        output_slice = batched_inference(
172            predictor=predictor,
173            image=volume[slice_choice],
174            batch_size=1,
175            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
176            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
177            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
178            verbose_embeddings=verbose,
179        )
180        output_seg = np.zeros_like(ground_truth)
181        output_seg[slice_choice][output_slice == 1] = 1
182
183        # Segment the object in the entire volume with the specified segmented slice
184        this_seg, _ = segment_mask_in_volume(
185            segmentation=output_seg,
186            predictor=predictor,
187            image_embeddings=embeddings,
188            segmented_slices=np.array(slice_choice),
189            stop_lower=False, stop_upper=False,
190            iou_threshold=iou_threshold,
191            projection=projection,
192            box_extension=box_extension,
193            verbose=verbose,
194        )
195
196        # Store the entire segmented object
197        final_segmentation[this_seg == 1] = label_id
198
199    # Save the volumetric segmentation
200    if save_path is not None:
201        if _segmentation_completed:
202            final_segmentation = imageio.imread(save_path)
203        else:
204            imageio.imwrite(save_path, final_segmentation, compression="zlib")
205
206    # Evaluate the volumetric segmentation
207    if skipped_label_ids:
208        curr_gt = ground_truth.copy()
209        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
210    else:
211        curr_gt = ground_truth
212
213    if evaluation_metric == "sa":
214        msa, sa = mean_segmentation_accuracy(
215            segmentation=final_segmentation, groundtruth=curr_gt, return_accuracies=True
216        )
217        results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
218
219    elif evaluation_metric == "dice":
220        # Calculate overall dice score (by binarizing all labels).
221        dice = dice_score(segmentation=final_segmentation, groundtruth=curr_gt)
222        results = {"Dice": dice}
223
224    elif evaluation_metric == "dice_per_class":
225        # Calculate dice per class.
226        dice = [
227            dice_score(segmentation=(final_segmentation == i), groundtruth=(curr_gt == i))
228            for i in np.unique(curr_gt)[1:]
229        ]
230        dice = np.mean(dice)
231        results = {"Dice": dice}
232
233    else:
234        raise ValueError(
235            f"'{evaluation_metric}' is not a supported evaluation metrics. "
236            "Please choose 'sa' / 'dice' / 'dice_per_class'."
237        )
238
239    if return_segmentation:
240        return results, final_segmentation
241    else:
242        return results
243
244
245def _get_best_parameters_from_grid_search_combinations(
246    result_dir, best_params_path, grid_search_values, evaluation_metric,
247):
248    if os.path.exists(best_params_path):
249        print("The best parameters are already saved at:", best_params_path)
250        return
251
252    criterion = "mSA" if evaluation_metric == "sa" else "Dice"
253    best_kwargs, best_metric = evaluate_instance_segmentation_grid_search(
254        result_dir=result_dir, grid_search_parameters=list(grid_search_values.keys()), criterion=criterion,
255    )
256
257    # let's save the best parameters
258    best_kwargs[criterion] = best_metric
259    best_param_df = pd.DataFrame.from_dict([best_kwargs])
260    best_param_df.to_csv(best_params_path)
261
262    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
263    print("Best grid-search result:", best_metric, "with parmeters:\n", best_param_str)
264
265
def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Optional[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice", "dice_per_class"] = "sa",
) -> str:
    """Run grid search for prompt-based multi-dimensional instance segmentation.

    The parameters and their respective value ranges for the grid search are specified via the
    `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`,
    `projection` and `box_extension`, you can pass the following:
    ```python
    grid_search_values = {
        "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
        "projection": ["mask", "box", "points"],
        "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
    }
    ```
    All combinations of the parameters will be checked.
    If passed None, the function `default_grid_search_values_multi_dimensional_segmentation` is used
    to get the default grid search parameters for the instance segmentation method.

    If the grid-search results already exist in `result_dir`, the segmentation is not rerun;
    only the best parameters are (re-)determined from the stored results.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        result_dir: Path to save the grid search results.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
        evaluation_metric: The choice of metric for evaluating predictions: 'sa', 'dice' or 'dice_per_class'.

    Returns:
        Filepath where the best parameters are saved.
    """
    if grid_search_values is None:
        grid_search_values = default_grid_search_values_multi_dimensional_segmentation()

    assert len(grid_search_values.keys()) == 3, "There must be three grid-search parameters. See above for details."

    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "all_grid_search_results.csv")
    best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")

    # The grid-search was already run: only determine the best parameters from the stored results.
    if os.path.exists(result_path):
        _get_best_parameters_from_grid_search_combinations(
            result_dir, best_params_path, grid_search_values, evaluation_metric
        )
        return best_params_path

    # Map each combination of the grid search values back to a valid kwarg input.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, values)) for values in product(*grid_search_values.values())]

    # One result row (metrics + parameters) per combination.
    rows = []
    for gs_kwargs in tqdm(gs_combinations, desc="Run grid-search for multi-dimensional segmentation"):
        results = segment_slices_from_ground_truth(
            volume=volume,
            ground_truth=ground_truth,
            model_type=model_type,
            checkpoint_path=checkpoint_path,
            embedding_path=embedding_path,
            interactive_seg_mode=interactive_seg_mode,
            verbose=verbose,
            return_segmentation=False,
            min_size=min_size,
            evaluation_metric=evaluation_metric,
            **gs_kwargs
        )
        rows.append({**results, **gs_kwargs})

    pd.DataFrame(rows).to_csv(result_path)

    _get_best_parameters_from_grid_search_combinations(
        result_dir, best_params_path, grid_search_values, evaluation_metric
    )
    print("The best grid-search parameters have been computed and stored at:", best_params_path)
    return best_params_path
def default_grid_search_values_multi_dimensional_segmentation( iou_threshold_values: Optional[List[float]] = None, projection_method_values: Union[str, dict, NoneType] = None, box_extension_values: Union[float, int, NoneType] = None) -> Dict[str, List]:
def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[List[Union[str, dict]]] = None,
    box_extension_values: Optional[List[Union[float, int]]] = None,
) -> Dict[str, List]:
    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

    Each argument is a list of candidate values (not a single value): every combination
    of the candidates is evaluated during the grid-search.

    Args:
        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        projection_method_values: The values for `projection` method used in the grid-search.
            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
        box_extension_values: The values for `box_extension` used in the grid-search.
            By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search: a dictionary mapping the parameter names `iou_threshold`,
        `projection` and `box_extension` to their candidate values.
    """
    if iou_threshold_values is None:
        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)

    if projection_method_values is None:
        projection_method_values = [
            "mask", "points", "box", "points_and_mask", "single_point"
        ]

    if box_extension_values is None:
        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)

    return {
        "iou_threshold": iou_threshold_values,
        "projection": projection_method_values,
        "box_extension": box_extension_values,
    }

Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

Arguments:
  • iou_threshold_values: The values for iou_threshold used in the grid-search. By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
  • projection_method_values: The values for projection method used in the grid-search. By default the values mask, points, box, points_and_mask and single_point are used.
  • box_extension_values: The values for box_extension used in the grid-search. By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
Returns:

The values for grid search.

@torch.no_grad()
def segment_slices_from_ground_truth( volume: numpy.ndarray, ground_truth: numpy.ndarray, model_type: str, checkpoint_path: Union[os.PathLike, str, NoneType] = None, embedding_path: Union[os.PathLike, str, NoneType] = None, save_path: Union[os.PathLike, str, NoneType] = None, iou_threshold: float = 0.8, projection: Union[str, dict] = 'mask', box_extension: Union[float, int] = 0.025, device: Union[str, torch.device] = None, interactive_seg_mode: str = 'box', verbose: bool = False, return_segmentation: bool = False, min_size: int = 0, evaluation_metric: Literal['sa', 'dice'] = 'sa') -> Union[Dict, Tuple[Dict, numpy.ndarray]]:
 59@torch.no_grad()
 60def segment_slices_from_ground_truth(
 61    volume: np.ndarray,
 62    ground_truth: np.ndarray,
 63    model_type: str,
 64    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 65    embedding_path: Optional[Union[str, os.PathLike]] = None,
 66    save_path: Optional[Union[str, os.PathLike]] = None,
 67    iou_threshold: float = 0.8,
 68    projection: Union[str, dict] = "mask",
 69    box_extension: Union[float, int] = 0.025,
 70    device: Union[str, torch.device] = None,
 71    interactive_seg_mode: str = "box",
 72    verbose: bool = False,
 73    return_segmentation: bool = False,
 74    min_size: int = 0,
 75    evaluation_metric: Literal["sa", "dice"] = "sa",
 76) -> Union[Dict, Tuple[Dict, np.ndarray]]:
 77    """Segment all objects in a volume by prompt-based segmentation in one slice per object.
 78
 79    This function first segments each object in the respective specified slice using interactive
 80    (prompt-based) segmentation functionality. Then it segments the particular object in the
 81    remaining slices in the volume.
 82
 83    Args:
 84        volume: The input volume.
 85        ground_truth: The label volume with instance segmentations.
 86        model_type: Choice of segment anything model.
 87        checkpoint_path: Path to the model checkpoint.
 88        embedding_path: Path to cache the computed embeddings.
 89        save_path: Path to store the segmentations.
 90        iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation.
 91        projection: The projection (prompting) method to generate prompts for consecutive slices.
 92        box_extension: Extension factor for increasing the box size after projection.
 93        device: The selected device for computation.
 94        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
 95        verbose: Whether to get the trace for projected segmentations.
 96        return_segmentation: Whether to return the segmented volume.
 97        min_size: The minimal size for evaluating an object in the ground-truth.
 98            The size is measured within the central slice.
 99        evaluation_metric: The choice of supported metric to evaluate predictions.
100
101    Returns:
102        A dictionary of results with all desired metrics.
103        Optional segmentation result (controlled by `return_segmentation` argument).
104    """
105    assert volume.ndim == 3
106
107    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)
108
109    # Compute the image embeddings
110    embeddings = util.precompute_image_embeddings(
111        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
112    )
113
114    # Compute instance ids (without the background)
115    label_ids = np.unique(ground_truth)[1:]
116    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."
117
118    # Create an empty volume to store incoming segmentations
119    final_segmentation = np.zeros_like(ground_truth)
120
121    _segmentation_completed = False
122    if save_path is not None and os.path.exists(save_path):
123        _segmentation_completed = True  # We avoid rerunning the segmentation if it is completed.
124
125    skipped_label_ids = []
126    for label_id in tqdm(label_ids, desc="Segmenting per object in the volume", disable=not verbose):
127        # Binary label volume per instance (also referred to as object)
128        this_seg = (ground_truth == label_id).astype("int")
129
130        # Let's search the slices where we have the current object
131        slice_range = np.where(this_seg)[0]
132
133        # Choose the middle slice of the current object for prompt-based segmentation
134        slice_range = (slice_range.min(), slice_range.max())
135        slice_choice = floor(np.mean(slice_range))
136        this_slice_seg = this_seg[slice_choice]
137        if min_size > 0 and this_slice_seg.sum() < min_size:
138            skipped_label_ids.append(label_id)
139            continue
140
141        if _segmentation_completed:
142            continue
143
144        if verbose:
145            print(f"The object with id {label_id} lies in slice range: {slice_range}")
146
147        # Prompts for segmentation for the current slice
148        if interactive_seg_mode == "points":
149            _get_points, _get_box = True, False
150        elif interactive_seg_mode == "box":
151            _get_points, _get_box = False, True
152        else:
153            raise ValueError(
154                f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
155                "Please choose from 'box' / 'points'."
156            )
157
158        prompt_generator = PointAndBoxPromptGenerator(
159            n_positive_points=1 if _get_points else 0,
160            n_negative_points=1 if _get_points else 0,
161            dilation_strength=10,
162            get_point_prompts=_get_points,
163            get_box_prompts=_get_box
164        )
165        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
166        point_prompts, point_labels, box_prompts, _ = prompt_generator(
167            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
168            bbox_coordinates=[box_coords[1]],
169        )
170
171        # Prompt-based segmentation on middle slice of the current object
172        output_slice = batched_inference(
173            predictor=predictor,
174            image=volume[slice_choice],
175            batch_size=1,
176            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
177            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
178            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
179            verbose_embeddings=verbose,
180        )
181        output_seg = np.zeros_like(ground_truth)
182        output_seg[slice_choice][output_slice == 1] = 1
183
184        # Segment the object in the entire volume with the specified segmented slice
185        this_seg, _ = segment_mask_in_volume(
186            segmentation=output_seg,
187            predictor=predictor,
188            image_embeddings=embeddings,
189            segmented_slices=np.array(slice_choice),
190            stop_lower=False, stop_upper=False,
191            iou_threshold=iou_threshold,
192            projection=projection,
193            box_extension=box_extension,
194            verbose=verbose,
195        )
196
197        # Store the entire segmented object
198        final_segmentation[this_seg == 1] = label_id
199
200    # Save the volumetric segmentation
201    if save_path is not None:
202        if _segmentation_completed:
203            final_segmentation = imageio.imread(save_path)
204        else:
205            imageio.imwrite(save_path, final_segmentation, compression="zlib")
206
207    # Evaluate the volumetric segmentation
208    if skipped_label_ids:
209        curr_gt = ground_truth.copy()
210        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
211    else:
212        curr_gt = ground_truth
213
214    if evaluation_metric == "sa":
215        msa, sa = mean_segmentation_accuracy(
216            segmentation=final_segmentation, groundtruth=curr_gt, return_accuracies=True
217        )
218        results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
219
220    elif evaluation_metric == "dice":
221        # Calculate overall dice score (by binarizing all labels).
222        dice = dice_score(segmentation=final_segmentation, groundtruth=curr_gt)
223        results = {"Dice": dice}
224
225    elif evaluation_metric == "dice_per_class":
226        # Calculate dice per class.
227        dice = [
228            dice_score(segmentation=(final_segmentation == i), groundtruth=(curr_gt == i))
229            for i in np.unique(curr_gt)[1:]
230        ]
231        dice = np.mean(dice)
232        results = {"Dice": dice}
233
234    else:
235        raise ValueError(
236            f"'{evaluation_metric}' is not a supported evaluation metrics. "
237            "Please choose 'sa' / 'dice' / 'dice_per_class'."
238        )
239
240    if return_segmentation:
241        return results, final_segmentation
242    else:
243        return results

Segment all objects in a volume by prompt-based segmentation in one slice per object.

This function first segments each object in the respective specified slice using interactive (prompt-based) segmentation functionality. Then it segments the particular object in the remaining slices in the volume.

Arguments:
  • volume: The input volume.
  • ground_truth: The label volume with instance segmentations.
  • model_type: Choice of segment anything model.
  • checkpoint_path: Path to the model checkpoint.
  • embedding_path: Path to cache the computed embeddings.
  • save_path: Path to store the segmentations.
  • iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation.
  • projection: The projection (prompting) method to generate prompts for consecutive slices.
  • box_extension: Extension factor for increasing the box size after projection.
  • device: The selected device for computation.
  • interactive_seg_mode: Method for guiding prompt-based instance segmentation.
  • verbose: Whether to get the trace for projected segmentations.
  • return_segmentation: Whether to return the segmented volume.
  • min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.
  • evaluation_metric: The choice of supported metric to evaluate predictions.
Returns:

A dictionary of results with all desired metrics. Optional segmentation result (controlled by return_segmentation argument).