micro_sam.evaluation.multi_dimensional_segmentation
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from math import floor
from itertools import product
from typing import Union, Tuple, Optional, List, Dict, Literal

import imageio.v3 as imageio

import torch

from elf.evaluation import mean_segmentation_accuracy, dice_score

from .. import util
from ..inference import batched_inference
from ..prompt_generators import PointAndBoxPromptGenerator
from ..multi_dimensional_segmentation import segment_mask_in_volume
from ..evaluation.instance_segmentation import _get_range_of_search_values, evaluate_instance_segmentation_grid_search


def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[Union[str, dict]] = None,
    box_extension_values: Optional[Union[float, int]] = None
) -> Dict[str, List]:
    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

    Args:
        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        projection_method_values: The values for the `projection` method used in the grid-search.
            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
        box_extension_values: The values for `box_extension` used in the grid-search.
            By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search.
    """
    if iou_threshold_values is None:
        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)

    if projection_method_values is None:
        projection_method_values = [
            "mask", "points", "box", "points_and_mask", "single_point"
        ]

    if box_extension_values is None:
        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)

    return {
        "iou_threshold": iou_threshold_values,
        "projection": projection_method_values,
        "box_extension": box_extension_values
    }


@torch.no_grad()
def segment_slices_from_ground_truth(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    save_path: Optional[Union[str, os.PathLike]] = None,
    iou_threshold: float = 0.8,
    projection: Union[str, dict] = "mask",
    box_extension: Union[float, int] = 0.025,
    device: Optional[Union[str, torch.device]] = None,
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    return_segmentation: bool = False,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice"] = "sa",
) -> Union[Dict[str, float], Tuple[Dict[str, float], np.ndarray]]:
    """Segment all objects in a volume by prompt-based segmentation in one slice per object.

    This function first segments each object in its specified (middle) slice using interactive
    (prompt-based) segmentation. Then it segments that object in the remaining slices of the volume.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        save_path: Path to store the segmentations.
        iou_threshold: The IoU threshold used to decide whether to link objects across consecutive slices.
        projection: The projection (prompting) method to generate prompts for consecutive slices.
        box_extension: Extension factor for increasing the box size after projection.
        device: The selected device for computation.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation. Either "box" or "points".
        verbose: Whether to print the trace of the projected segmentations.
        return_segmentation: Whether to return the segmented volume.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
        evaluation_metric: The metric used to evaluate the predictions. Either "sa" or "dice".

    Returns:
        A dictionary with the evaluation results, and the segmented volume if `return_segmentation` is True.
    """
    assert volume.ndim == 3

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)

    # Compute the image embeddings.
    embeddings = util.precompute_image_embeddings(
        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
    )

    # Compute the instance ids (without the background).
    label_ids = np.unique(ground_truth)[1:]
    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."

    # Create an empty volume to store the incoming segmentations.
    final_segmentation = np.zeros_like(ground_truth)

    _segmentation_completed = False
    if save_path is not None and os.path.exists(save_path):
        _segmentation_completed = True  # We avoid rerunning the segmentation if it is completed.

    skipped_label_ids = []
    for label_id in tqdm(label_ids, desc="Segmenting per object in the volume", disable=not verbose):
        # Binary label volume per instance (also referred to as object).
        this_seg = (ground_truth == label_id).astype("int")

        # Find the slices that contain the current object.
        slice_range = np.where(this_seg)[0]

        # Choose the middle slice of the current object for prompt-based segmentation.
        slice_range = (slice_range.min(), slice_range.max())
        slice_choice = floor(np.mean(slice_range))
        this_slice_seg = this_seg[slice_choice]
        if min_size > 0 and this_slice_seg.sum() < min_size:
            skipped_label_ids.append(label_id)
            continue

        if _segmentation_completed:
            continue

        if verbose:
            print(f"The object with id {label_id} lies in slice range: {slice_range}")

        # Prompts for the segmentation of the current slice.
        if interactive_seg_mode == "points":
            _get_points, _get_box = True, False
        elif interactive_seg_mode == "box":
            _get_points, _get_box = False, True
        else:
            raise ValueError(
                f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
                "Please choose from 'box' / 'points'."
            )

        prompt_generator = PointAndBoxPromptGenerator(
            n_positive_points=1 if _get_points else 0,
            n_negative_points=1 if _get_points else 0,
            dilation_strength=10,
            get_point_prompts=_get_points,
            get_box_prompts=_get_box
        )
        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
        point_prompts, point_labels, box_prompts, _ = prompt_generator(
            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
            bbox_coordinates=[box_coords[1]],
        )

        # Prompt-based segmentation on the middle slice of the current object.
        output_slice = batched_inference(
            predictor=predictor,
            image=volume[slice_choice],
            batch_size=1,
            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
            verbose_embeddings=verbose,
        )
        output_seg = np.zeros_like(ground_truth)
        output_seg[slice_choice][output_slice == 1] = 1

        # Segment the object in the entire volume, starting from the segmented slice.
        this_seg, _ = segment_mask_in_volume(
            segmentation=output_seg,
            predictor=predictor,
            image_embeddings=embeddings,
            segmented_slices=np.array(slice_choice),
            stop_lower=False, stop_upper=False,
            iou_threshold=iou_threshold,
            projection=projection,
            box_extension=box_extension,
            verbose=verbose,
        )

        # Store the entire segmented object.
        final_segmentation[this_seg == 1] = label_id

    # Save the volumetric segmentation.
    if save_path is not None:
        if _segmentation_completed:
            final_segmentation = imageio.imread(save_path)
        else:
            imageio.imwrite(save_path, final_segmentation, compression="zlib")

    # Evaluate the volumetric segmentation.
    if skipped_label_ids:
        curr_gt = ground_truth.copy()
        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
    else:
        curr_gt = ground_truth

    if evaluation_metric == "sa":
        msa, sa = mean_segmentation_accuracy(
            segmentation=final_segmentation, groundtruth=curr_gt, return_accuracies=True
        )
        results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
    elif evaluation_metric == "dice":
        dice = dice_score(segmentation=final_segmentation, groundtruth=curr_gt)
        results = {"Dice": dice}
    else:
        raise ValueError(f"'{evaluation_metric}' is not a supported evaluation metric. Please choose 'sa' / 'dice'.")

    if return_segmentation:
        return results, final_segmentation
    else:
        return results


def _get_best_parameters_from_grid_search_combinations(
    result_dir, best_params_path, grid_search_values, evaluation_metric,
):
    if os.path.exists(best_params_path):
        print("The best parameters are already saved at:", best_params_path)
        return

    criterion = "mSA" if evaluation_metric == "sa" else "Dice"
    best_kwargs, best_metric = evaluate_instance_segmentation_grid_search(
        result_dir=result_dir, grid_search_parameters=list(grid_search_values.keys()), criterion=criterion,
    )

    # Save the best parameters.
    best_kwargs[criterion] = best_metric
    best_param_df = pd.DataFrame.from_dict([best_kwargs])
    best_param_df.to_csv(best_params_path)

    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_metric, "with parameters:\n", best_param_str)


def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Optional[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0,
    evaluation_metric: Literal["sa", "dice"] = "sa",
):
    """Run a grid search for prompt-based multi-dimensional instance segmentation.

    The parameters and their respective value ranges for the grid search are specified via the
    `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`,
    `projection` and `box_extension`, you can pass the following:
    ```
    grid_search_values = {
        "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
        "projection": ["mask", "box", "points"],
        "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
    }
    ```
    All combinations of the parameters will be checked.
    If `None` is passed, the function `default_grid_search_values_multi_dimensional_segmentation` is used
    to get the default grid search parameters for the instance segmentation method.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        result_dir: Path to save the grid search results.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation. Either "box" or "points".
        verbose: Whether to print the trace of the projected segmentations.
        grid_search_values: The grid search values for the parameters of the `segment_slices_from_ground_truth` function.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
        evaluation_metric: The metric used to evaluate the predictions. Either "sa" or "dice".

    Returns:
        The path to the csv file with the best grid-search parameters.
    """
    if grid_search_values is None:
        grid_search_values = default_grid_search_values_multi_dimensional_segmentation()

    assert len(grid_search_values.keys()) == 3, "There must be three grid-search parameters. See above for details."

    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "all_grid_search_results.csv")
    best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")
    if os.path.exists(result_path):
        _get_best_parameters_from_grid_search_combinations(
            result_dir, best_params_path, grid_search_values, evaluation_metric
        )
        return best_params_path

    # Compute all combinations of the grid search values.
    gs_combinations = product(*grid_search_values.values())

    # Map each combination back to a valid kwarg input.
    gs_combinations = [
        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
    ]

    net_list = []
    for gs_kwargs in tqdm(gs_combinations, desc="Run grid-search for multi-dimensional segmentation"):
        results = segment_slices_from_ground_truth(
            volume=volume,
            ground_truth=ground_truth,
            model_type=model_type,
            checkpoint_path=checkpoint_path,
            embedding_path=embedding_path,
            interactive_seg_mode=interactive_seg_mode,
            verbose=verbose,
            return_segmentation=False,
            min_size=min_size,
            evaluation_metric=evaluation_metric,
            **gs_kwargs
        )

        result_dict = {**results, **gs_kwargs}
        tmp_df = pd.DataFrame([result_dict])
        net_list.append(tmp_df)

    res_df = pd.concat(net_list, ignore_index=True)
    res_df.to_csv(result_path)

    _get_best_parameters_from_grid_search_combinations(
        result_dir, best_params_path, grid_search_values, evaluation_metric
    )
    print("The best grid-search parameters have been computed and stored at:", best_params_path)
    return best_params_path
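For intuition, the middle-slice selection used per object above can be reproduced with plain numpy. The toy mask below is hypothetical and only illustrates how the slice for prompting is chosen:

```python
import numpy as np
from math import floor

# Hypothetical toy data: a binary mask for one object that spans slices 2-6 of a 10-slice volume.
this_seg = np.zeros((10, 8, 8), dtype="int")
this_seg[2:7, 3:6, 3:6] = 1

# Find the z-slices that contain the object.
z_indices = np.where(this_seg)[0]
slice_range = (z_indices.min(), z_indices.max())  # (2, 6)

# The prompt-based segmentation is run on the middle slice of this range.
slice_choice = floor(np.mean(slice_range))  # floor((2 + 6) / 2) = 4
this_slice_seg = this_seg[slice_choice]

print(slice_range, slice_choice, this_slice_seg.sum())  # (2, 6) 4 9
```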
def default_grid_search_values_multi_dimensional_segmentation(iou_threshold_values=None, projection_method_values=None, box_extension_values=None) -> Dict[str, List]
Default grid-search parameters for multi-dimensional prompt-based instance segmentation.
Arguments:
- iou_threshold_values: The values for `iou_threshold` used in the grid-search. By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
- projection_method_values: The values for the `projection` method used in the grid-search. By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
- box_extension_values: The values for `box_extension` used in the grid-search. By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
Returns:
The values for grid search.
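A short usage sketch of the defaults. The values shown in the comments reflect the ranges described above; the exact floating-point values depend on the range helper:

```python
from micro_sam.evaluation.multi_dimensional_segmentation import (
    default_grid_search_values_multi_dimensional_segmentation,
)

# Use the default search space.
grid = default_grid_search_values_multi_dimensional_segmentation()
print(grid["projection"])     # ['mask', 'points', 'box', 'points_and_mask', 'single_point']
print(grid["iou_threshold"])  # values from 0.5 to 0.9 in steps of 0.1

# Or override a single dimension of the search, e.g. a coarser box_extension range.
grid = default_grid_search_values_multi_dimensional_segmentation(
    box_extension_values=[0.0, 0.1, 0.2],
)
```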
@torch.no_grad()
def segment_slices_from_ground_truth(volume, ground_truth, model_type, checkpoint_path=None, embedding_path=None, save_path=None, iou_threshold=0.8, projection="mask", box_extension=0.025, device=None, interactive_seg_mode="box", verbose=False, return_segmentation=False, min_size=0, evaluation_metric="sa")
Segment all objects in a volume by prompt-based segmentation in one slice per object.
This function first segments each object in its specified (middle) slice using interactive (prompt-based) segmentation. Then it segments that object in the remaining slices of the volume.
Arguments:
- volume: The input volume.
- ground_truth: The label volume with instance segmentations.
- model_type: Choice of segment anything model.
- checkpoint_path: Path to the model checkpoint.
- embedding_path: Path to cache the computed embeddings.
- save_path: Path to store the segmentations.
- iou_threshold: The IoU threshold used to decide whether to link objects across consecutive slices.
- projection: The projection (prompting) method to generate prompts for consecutive slices.
- box_extension: Extension factor for increasing the box size after projection.
- device: The selected device for computation.
- interactive_seg_mode: Method for guiding prompt-based instance segmentation. Either 'box' or 'points'.
- verbose: Whether to print the trace of the projected segmentations.
- return_segmentation: Whether to return the segmented volume.
- min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.
- evaluation_metric: The metric used to evaluate the predictions. Either 'sa' or 'dice'.

Returns:
A dictionary with the evaluation results, and the segmented volume if `return_segmentation` is True.
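A minimal usage sketch, assuming a 3D raw volume and matching instance labels stored as tif files; the file paths and the `vit_b` model choice are illustrative, not prescribed by this module:

```python
import imageio.v3 as imageio
from micro_sam.evaluation.multi_dimensional_segmentation import segment_slices_from_ground_truth

# Hypothetical paths to a 3D raw volume and its instance labels.
volume = imageio.imread("data/raw_volume.tif")
ground_truth = imageio.imread("data/labels.tif")

results, segmentation = segment_slices_from_ground_truth(
    volume=volume,
    ground_truth=ground_truth,
    model_type="vit_b",                       # assumed SAM model choice
    embedding_path="embeddings/volume.zarr",  # cache embeddings so re-runs are cheap
    iou_threshold=0.8,
    projection="mask",
    box_extension=0.025,
    interactive_seg_mode="box",
    return_segmentation=True,
    evaluation_metric="sa",
)
print(results)  # e.g. {"mSA": ..., "SA50": ..., "SA75": ...}
```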
def run_multi_dimensional_segmentation_grid_search(volume, ground_truth, model_type, checkpoint_path, embedding_path, result_dir, interactive_seg_mode="box", verbose=False, grid_search_values=None, min_size=0, evaluation_metric="sa")
Run grid search for prompt-based multi-dimensional instance segmentation.
The parameters and their respective value ranges for the grid search are specified via the `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`, `projection` and `box_extension`, you can pass the following:

```
grid_search_values = {
    "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
    "projection": ["mask", "box", "points"],
    "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
}
```

All combinations of the parameters will be checked.
If `None` is passed, the function `default_grid_search_values_multi_dimensional_segmentation` is used to get the default grid search parameters for the instance segmentation method.
Arguments:
- volume: The input volume.
- ground_truth: The label volume with instance segmentations.
- model_type: Choice of segment anything model.
- checkpoint_path: Path to the model checkpoint.
- embedding_path: Path to cache the computed embeddings.
- result_dir: Path to save the grid search results.
- interactive_seg_mode: Method for guiding prompt-based instance segmentation. Either 'box' or 'points'.
- verbose: Whether to print the trace of the projected segmentations.
- grid_search_values: The grid search values for the parameters of the `segment_slices_from_ground_truth` function.
- min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.
- evaluation_metric: The metric used to evaluate the predictions. Either 'sa' or 'dice'.

Returns:
The path to the csv file with the best grid-search parameters.
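A minimal usage sketch for the grid search, with hypothetical paths for the data, checkpoint and output directory:

```python
import pandas as pd
import imageio.v3 as imageio
from micro_sam.evaluation.multi_dimensional_segmentation import (
    run_multi_dimensional_segmentation_grid_search,
)

# Hypothetical inputs; replace with your own volume, labels and checkpoint.
volume = imageio.imread("data/raw_volume.tif")
ground_truth = imageio.imread("data/labels.tif")

best_params_path = run_multi_dimensional_segmentation_grid_search(
    volume=volume,
    ground_truth=ground_truth,
    model_type="vit_b",
    checkpoint_path="checkpoints/vit_b_finetuned.pt",
    embedding_path="embeddings/volume.zarr",
    result_dir="gs_results",
    evaluation_metric="sa",
)

# Every parameter combination is stored in gs_results/all_grid_search_results.csv;
# the selected best combination is written to the returned csv file.
print(pd.read_csv(best_params_path))
```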