micro_sam.evaluation.multi_dimensional_segmentation
```python
import os
from tqdm import tqdm
from itertools import product
from typing import Union, Tuple, Optional, List, Dict, Literal

import numpy as np
import pandas as pd
from math import floor
import imageio.v3 as imageio

import torch

from elf.evaluation import mean_segmentation_accuracy, dice_score

from .. import util
from ..inference import batched_inference
from ..prompt_generators import PointAndBoxPromptGenerator
from ..multi_dimensional_segmentation import segment_mask_in_volume
from ..evaluation.instance_segmentation import _get_range_of_search_values, evaluate_instance_segmentation_grid_search


def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[Union[str, dict]] = None,
    box_extension_values: Optional[Union[float, int]] = None
) -> Dict[str, List]:
    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

    Args:
        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
            By default, values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        projection_method_values: The values for the `projection` method used in the grid-search.
            By default, the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
        box_extension_values: The values for `box_extension` used in the grid-search.
            By default, values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.

    Returns:
        The values for the grid search.
    """
    if iou_threshold_values is None:
        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)

    if projection_method_values is None:
        projection_method_values = [
            "mask", "points", "box", "points_and_mask", "single_point"
        ]

    if box_extension_values is None:
        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)

    return {
        "iou_threshold": iou_threshold_values,
        "projection": projection_method_values,
        "box_extension": box_extension_values
    }


@torch.no_grad()
def segment_slices_from_ground_truth(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    save_path: Optional[Union[str, os.PathLike]] = None,
    iou_threshold: float = 0.8,
    projection: Union[str, dict] = "mask",
    box_extension: Union[float, int] = 0.025,
    device: Optional[Union[str, torch.device]] = None,
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    return_segmentation: bool = False,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice", "dice_per_class"] = "sa",
) -> Union[Dict, Tuple[Dict, np.ndarray]]:
    """Segment all objects in a volume by prompt-based segmentation in one slice per object.

    This function first segments each object in its central slice using interactive
    (prompt-based) segmentation. It then segments the same object in the
    remaining slices of the volume.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        save_path: Path to store the segmentations.
        iou_threshold: The IoU threshold used to decide whether to link an object across consecutive slice segmentations.
        projection: The projection (prompting) method to generate prompts for consecutive slices.
        box_extension: Extension factor for increasing the box size after projection.
        device: The selected device for computation.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        return_segmentation: Whether to return the segmented volume.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
        evaluation_metric: The metric used to evaluate the predictions. One of 'sa', 'dice' or 'dice_per_class'.

    Returns:
        A dictionary of results with all desired metrics.
        Optionally, the segmentation result (controlled by the `return_segmentation` argument).
    """
    assert volume.ndim == 3

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)

    # Compute the image embeddings.
    embeddings = util.precompute_image_embeddings(
        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
    )

    # Compute instance ids (without the background).
    label_ids = np.unique(ground_truth)[1:]
    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."

    # Create an empty volume to store incoming segmentations.
    final_segmentation = np.zeros_like(ground_truth)

    _segmentation_completed = False
    if save_path is not None and os.path.exists(save_path):
        _segmentation_completed = True  # We avoid rerunning the segmentation if it is completed.

    skipped_label_ids = []
    for label_id in tqdm(label_ids, desc="Segmenting per object in the volume", disable=not verbose):
        # Binary label volume per instance (also referred to as object).
        this_seg = (ground_truth == label_id).astype("int")

        # Let's search the slices where we have the current object.
        slice_range = np.where(this_seg)[0]

        # Choose the middle slice of the current object for prompt-based segmentation.
        slice_range = (slice_range.min(), slice_range.max())
        slice_choice = floor(np.mean(slice_range))
        this_slice_seg = this_seg[slice_choice]
        if min_size > 0 and this_slice_seg.sum() < min_size:
            skipped_label_ids.append(label_id)
            continue

        if _segmentation_completed:
            continue

        if verbose:
            print(f"The object with id {label_id} lies in slice range: {slice_range}")

        # Prompts for segmentation for the current slice.
        if interactive_seg_mode == "points":
            _get_points, _get_box = True, False
        elif interactive_seg_mode == "box":
            _get_points, _get_box = False, True
        else:
            raise ValueError(
                f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
                "Please choose from 'box' / 'points'."
            )

        prompt_generator = PointAndBoxPromptGenerator(
            n_positive_points=1 if _get_points else 0,
            n_negative_points=1 if _get_points else 0,
            dilation_strength=10,
            get_point_prompts=_get_points,
            get_box_prompts=_get_box
        )
        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
        point_prompts, point_labels, box_prompts, _ = prompt_generator(
            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
            bbox_coordinates=[box_coords[1]],
        )

        # Prompt-based segmentation on the middle slice of the current object.
        output_slice = batched_inference(
            predictor=predictor,
            image=volume[slice_choice],
            batch_size=1,
            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
            verbose_embeddings=verbose,
        )
        output_seg = np.zeros_like(ground_truth)
        output_seg[slice_choice][output_slice == 1] = 1

        # Segment the object in the entire volume, starting from the segmented slice.
        this_seg, _ = segment_mask_in_volume(
            segmentation=output_seg,
            predictor=predictor,
            image_embeddings=embeddings,
            segmented_slices=np.array(slice_choice),
            stop_lower=False, stop_upper=False,
            iou_threshold=iou_threshold,
            projection=projection,
            box_extension=box_extension,
            verbose=verbose,
        )

        # Store the entire segmented object.
        final_segmentation[this_seg == 1] = label_id

    # Save the volumetric segmentation.
    if save_path is not None:
        if _segmentation_completed:
            final_segmentation = imageio.imread(save_path)
        else:
            imageio.imwrite(save_path, final_segmentation, compression="zlib")

    # Evaluate the volumetric segmentation.
    if skipped_label_ids:
        curr_gt = ground_truth.copy()
        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
    else:
        curr_gt = ground_truth

    if evaluation_metric == "sa":
        msa, sa = mean_segmentation_accuracy(
            segmentation=final_segmentation, groundtruth=curr_gt, return_accuracies=True
        )
        results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}

    elif evaluation_metric == "dice":
        # Calculate the overall dice score (by binarizing all labels).
        dice = dice_score(segmentation=final_segmentation, groundtruth=curr_gt)
        results = {"Dice": dice}

    elif evaluation_metric == "dice_per_class":
        # Calculate the dice score per class.
        dice = [
            dice_score(segmentation=(final_segmentation == i), groundtruth=(curr_gt == i))
            for i in np.unique(curr_gt)[1:]
        ]
        dice = np.mean(dice)
        results = {"Dice": dice}

    else:
        raise ValueError(
            f"'{evaluation_metric}' is not a supported evaluation metric. "
            "Please choose 'sa' / 'dice' / 'dice_per_class'."
        )

    if return_segmentation:
        return results, final_segmentation
    else:
        return results


def _get_best_parameters_from_grid_search_combinations(
    result_dir, best_params_path, grid_search_values, evaluation_metric,
):
    if os.path.exists(best_params_path):
        print("The best parameters are already saved at:", best_params_path)
        return

    criterion = "mSA" if evaluation_metric == "sa" else "Dice"
    best_kwargs, best_metric = evaluate_instance_segmentation_grid_search(
        result_dir=result_dir, grid_search_parameters=list(grid_search_values.keys()), criterion=criterion,
    )

    # Let's save the best parameters.
    best_kwargs[criterion] = best_metric
    best_param_df = pd.DataFrame.from_dict([best_kwargs])
    best_param_df.to_csv(best_params_path)

    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_metric, "with parameters:\n", best_param_str)


def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Optional[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice", "dice_per_class"] = "sa",
) -> str:
    """Run grid search for prompt-based multi-dimensional instance segmentation.

    The parameters and their respective value ranges for the grid search are specified via the
    `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`,
    `projection` and `box_extension`, you can pass the following:
    ```python
    grid_search_values = {
        "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
        "projection": ["mask", "box", "points"],
        "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
    }
    ```
    All combinations of the parameters will be checked.
    If passed None, the function `default_grid_search_values_multi_dimensional_segmentation` is used
    to get the default grid-search parameters for the instance segmentation method.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        result_dir: Path to save the grid search results.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        grid_search_values: The grid-search values for parameters of the `segment_slices_from_ground_truth` function.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
        evaluation_metric: The choice of metric for evaluating predictions.

    Returns:
        Filepath where the best parameters are saved.
    """
    if grid_search_values is None:
        grid_search_values = default_grid_search_values_multi_dimensional_segmentation()

    assert len(grid_search_values.keys()) == 3, "There must be three grid-search parameters. See above for details."

    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "all_grid_search_results.csv")
    best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")
    if os.path.exists(result_path):
        _get_best_parameters_from_grid_search_combinations(
            result_dir, best_params_path, grid_search_values, evaluation_metric
        )
        return best_params_path

    # Compute all combinations of grid search values.
    gs_combinations = product(*grid_search_values.values())

    # Map each combination back to a valid kwarg input.
    gs_combinations = [
        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
    ]

    net_list = []
    for gs_kwargs in tqdm(gs_combinations, desc="Run grid-search for multi-dimensional segmentation"):
        results = segment_slices_from_ground_truth(
            volume=volume,
            ground_truth=ground_truth,
            model_type=model_type,
            checkpoint_path=checkpoint_path,
            embedding_path=embedding_path,
            interactive_seg_mode=interactive_seg_mode,
            verbose=verbose,
            return_segmentation=False,
            min_size=min_size,
            evaluation_metric=evaluation_metric,
            **gs_kwargs
        )

        result_dict = {**results, **gs_kwargs}
        tmp_df = pd.DataFrame([result_dict])
        net_list.append(tmp_df)

    res_df = pd.concat(net_list, ignore_index=True)
    res_df.to_csv(result_path)

    _get_best_parameters_from_grid_search_combinations(
        result_dir, best_params_path, grid_search_values, evaluation_metric
    )
    print("The best grid-search parameters have been computed and stored at:", best_params_path)
    return best_params_path
```
```python
def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[Union[str, dict]] = None,
    box_extension_values: Optional[Union[float, int]] = None
) -> Dict[str, List]:
```
Default grid-search parameters for multi-dimensional prompt-based instance segmentation.
Arguments:
- iou_threshold_values: The values for `iou_threshold` used in the grid-search. By default, values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
- projection_method_values: The values for the `projection` method used in the grid-search. By default, the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
- box_extension_values: The values for `box_extension` used in the grid-search. By default, values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
Returns:
The values for the grid search.
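For illustration, a minimal usage sketch of this function; the override values are hypothetical, and the commented outputs assume the default search ranges described above:

```python
from micro_sam.evaluation.multi_dimensional_segmentation import (
    default_grid_search_values_multi_dimensional_segmentation,
)

# Keep the default ranges for `iou_threshold` and `box_extension`,
# but restrict the projection methods the grid search will try.
grid_search_values = default_grid_search_values_multi_dimensional_segmentation(
    projection_method_values=["mask", "box"],
)
print(grid_search_values["iou_threshold"])  # values from 0.5 to 0.9 in steps of 0.1
print(grid_search_values["projection"])     # ['mask', 'box']
```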
```python
@torch.no_grad()
def segment_slices_from_ground_truth(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    save_path: Optional[Union[str, os.PathLike]] = None,
    iou_threshold: float = 0.8,
    projection: Union[str, dict] = "mask",
    box_extension: Union[float, int] = 0.025,
    device: Optional[Union[str, torch.device]] = None,
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    return_segmentation: bool = False,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice", "dice_per_class"] = "sa",
) -> Union[Dict, Tuple[Dict, np.ndarray]]:
```
Segment all objects in a volume by prompt-based segmentation in one slice per object.
This function first segments each object in its central slice using interactive (prompt-based) segmentation. It then segments the same object in the remaining slices of the volume.
Arguments:
- volume: The input volume.
- ground_truth: The label volume with instance segmentations.
- model_type: Choice of segment anything model.
- checkpoint_path: Path to the model checkpoint.
- embedding_path: Path to cache the computed embeddings.
- save_path: Path to store the segmentations.
- iou_threshold: The IoU threshold used to decide whether to link an object across consecutive slice segmentations.
- projection: The projection (prompting) method to generate prompts for consecutive slices.
- box_extension: Extension factor for increasing the box size after projection.
- device: The selected device for computation.
- interactive_seg_mode: Method for guiding prompt-based instance segmentation.
- verbose: Whether to get the trace for projected segmentations.
- return_segmentation: Whether to return the segmented volume.
- min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.
- evaluation_metric: The metric used to evaluate the predictions. One of 'sa', 'dice' or 'dice_per_class'.
Returns:
A dictionary of results with all desired metrics. Optionally, the segmentation result (controlled by the `return_segmentation` argument).
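A hedged usage sketch: the file names and the `vit_b` model choice are placeholders, and `raw` and `labels` are assumed to be 3D arrays of the same shape:

```python
import imageio.v3 as imageio
from micro_sam.evaluation.multi_dimensional_segmentation import segment_slices_from_ground_truth

raw = imageio.imread("em_volume.tif")     # hypothetical input volume (z, y, x)
labels = imageio.imread("em_labels.tif")  # hypothetical instance labels, same shape

results, segmentation = segment_slices_from_ground_truth(
    volume=raw,
    ground_truth=labels,
    model_type="vit_b",                # any supported segment anything model identifier
    embedding_path="embeddings.zarr",  # cache the embeddings so reruns are cheap
    iou_threshold=0.8,
    projection="mask",
    box_extension=0.025,
    evaluation_metric="sa",
    return_segmentation=True,
)
print(results)  # e.g. {"mSA": ..., "SA50": ..., "SA75": ...}
```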
```python
def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Optional[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice", "dice_per_class"] = "sa",
) -> str:
```
Run grid search for prompt-based multi-dimensional instance segmentation.
The parameters and their respective value ranges for the grid search are specified via the `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`, `projection` and `box_extension`, you can pass the following:

```python
grid_search_values = {
    "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
    "projection": ["mask", "box", "points"],
    "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
}
```
All combinations of the parameters will be checked.
If passed None, the function `default_grid_search_values_multi_dimensional_segmentation` is used to get the default grid-search parameters for the instance segmentation method.
Arguments:
- volume: The input volume.
- ground_truth: The label volume with instance segmentations.
- model_type: Choice of segment anything model.
- checkpoint_path: Path to the model checkpoint.
- embedding_path: Path to cache the computed embeddings.
- result_dir: Path to save the grid search results.
- interactive_seg_mode: Method for guiding prompt-based instance segmentation.
- verbose: Whether to get the trace for projected segmentations.
- grid_search_values: The grid-search values for parameters of the `segment_slices_from_ground_truth` function.
- min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.
- evaluation_metric: The choice of metric for evaluating predictions.
Returns:
Filepath where the best parameters are saved.
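Putting it together, a sketch of running the grid search and reading back the best parameters; the paths and model choice are placeholders, and `raw` and `labels` are assumed to be 3D arrays as in the previous example:

```python
import pandas as pd
from micro_sam.evaluation.multi_dimensional_segmentation import (
    run_multi_dimensional_segmentation_grid_search,
)

best_params_path = run_multi_dimensional_segmentation_grid_search(
    volume=raw,
    ground_truth=labels,
    model_type="vit_b",
    checkpoint_path=None,              # None is forwarded and falls back to the default weights
    embedding_path="embeddings.zarr",
    result_dir="grid_search_results",  # receives all_grid_search_results.csv and the best-params CSV
)

# The best combination (plus its score) is stored as a one-row CSV.
print(pd.read_csv(best_params_path))
```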