micro_sam.evaluation.multi_dimensional_segmentation
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from math import floor
from itertools import product
from typing import Union, Tuple, Optional, List, Dict

import imageio.v3 as imageio

import torch

from elf.evaluation import mean_segmentation_accuracy

from .. import util
from ..inference import batched_inference
from ..prompt_generators import PointAndBoxPromptGenerator
from ..multi_dimensional_segmentation import segment_mask_in_volume
from ..evaluation.instance_segmentation import _get_range_of_search_values, evaluate_instance_segmentation_grid_search


def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[List[Union[str, dict]]] = None,
    box_extension_values: Optional[List[Union[float, int]]] = None,
) -> Dict[str, List]:
    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

    Args:
        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        projection_method_values: The values for `projection` method used in the grid-search.
            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
        box_extension_values: The values for `box_extension` used in the grid-search.
            By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by the corresponding parameter names
        of `segment_slices_from_ground_truth`.
    """
    if iou_threshold_values is None:
        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)

    if projection_method_values is None:
        projection_method_values = [
            "mask", "points", "box", "points_and_mask", "single_point"
        ]

    if box_extension_values is None:
        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)

    return {
        "iou_threshold": iou_threshold_values,
        "projection": projection_method_values,
        "box_extension": box_extension_values,
    }


@torch.no_grad()
def segment_slices_from_ground_truth(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    save_path: Optional[Union[str, os.PathLike]] = None,
    iou_threshold: float = 0.8,
    projection: Union[str, dict] = "mask",
    box_extension: Union[float, int] = 0.025,
    device: Optional[Union[str, torch.device]] = None,
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    return_segmentation: bool = False,
    min_size: int = 0,
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]]:
    """Segment all objects in a volume by prompt-based segmentation in one slice per object.

    This function first segments each object in its middle slice using interactive
    (prompt-based) segmentation with prompts derived from the ground-truth. Then it segments
    the particular object in the remaining slices in the volume.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations. The id 0 is treated as background.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        save_path: Path to store the segmentations.
        iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation.
        projection: The projection (prompting) method to generate prompts for consecutive slices.
        box_extension: Extension factor for increasing the box size after projection.
        device: The selected device for computation.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation in the
            first (middle) slice. Either 'box' or 'points'.
        verbose: Whether to get the trace for projected segmentations.
        return_segmentation: Whether to return the segmented volume.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.

    Returns:
        A single-row dataframe with the scores 'mSA', 'SA50' and 'SA75'.
        If `return_segmentation` is True, a tuple of this dataframe and the segmented volume.

    Raises:
        ValueError: If `interactive_seg_mode` is neither 'box' nor 'points'.
    """
    assert volume.ndim == 3

    # Validate the prompting mode up-front, before loading the model and computing the embeddings.
    if interactive_seg_mode == "points":
        get_points, get_box = True, False
    elif interactive_seg_mode == "box":
        get_points, get_box = False, True
    else:
        raise ValueError(
            f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
            "Please choose from 'box' / 'points'."
        )

    # Compute instance ids (without the background id 0). Filtering explicitly instead of
    # dropping the first unique value keeps the smallest object if there is no background.
    label_ids = np.unique(ground_truth)
    label_ids = label_ids[label_ids != 0]
    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)

    # Compute the image embeddings
    embeddings = util.precompute_image_embeddings(
        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
    )

    # The prompt generator only depends on the prompting mode, so it is created once.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1 if get_points else 0,
        n_negative_points=1 if get_points else 0,
        dilation_strength=10,
        get_point_prompts=get_points,
        get_box_prompts=get_box,
    )

    # Create an empty volume to store incoming segmentations
    final_segmentation = np.zeros_like(ground_truth)

    skipped_label_ids = []
    for label_id in label_ids:
        # Binary label volume per instance (also referred to as object)
        this_seg = (ground_truth == label_id).astype("int")

        # Find the slices (along the first axis) that contain the current object
        # and choose the middle slice for prompt-based segmentation.
        object_slices = np.where(this_seg)[0]
        slice_range = (object_slices.min(), object_slices.max())
        slice_choice = floor(np.mean(slice_range))
        this_slice_seg = this_seg[slice_choice]
        if min_size > 0 and this_slice_seg.sum() < min_size:
            skipped_label_ids.append(label_id)
            continue

        # An object that is not contiguous along the first axis may be absent from its middle slice
        # (only reachable here for min_size == 0). Prompts cannot be derived from an empty mask,
        # so fall back to the closest slice that contains the object.
        if not this_slice_seg.any():
            candidate_slices = np.unique(object_slices)
            slice_choice = int(candidate_slices[np.argmin(np.abs(candidate_slices - slice_choice))])
            this_slice_seg = this_seg[slice_choice]

        if verbose:
            print(f"The object with id {label_id} lies in slice range: {slice_range}")

        # Derive the prompts for the chosen slice from its ground-truth mask (label 1 in the binary mask).
        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
        point_prompts, point_labels, box_prompts, _ = prompt_generator(
            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
            bbox_coordinates=[box_coords[1]],
        )

        # Prompt-based segmentation on the chosen slice of the current object
        output_slice = batched_inference(
            predictor=predictor,
            image=volume[slice_choice],
            batch_size=1,
            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
            verbose_embeddings=verbose,
        )
        output_seg = np.zeros_like(ground_truth)
        output_seg[slice_choice][output_slice == 1] = 1

        # Segment the object in the entire volume with the specified segmented slice
        this_seg, _ = segment_mask_in_volume(
            segmentation=output_seg,
            predictor=predictor,
            image_embeddings=embeddings,
            segmented_slices=np.array(slice_choice),
            stop_lower=False,
            stop_upper=False,
            iou_threshold=iou_threshold,
            projection=projection,
            box_extension=box_extension,
            verbose=verbose,
        )

        # Store the entire segmented object (later objects overwrite earlier ones where they overlap)
        final_segmentation[this_seg == 1] = label_id

    # Save the volumetric segmentation
    if save_path is not None:
        imageio.imwrite(save_path, final_segmentation, compression="zlib")

    # Evaluate the volumetric segmentation, excluding the objects that were skipped due to `min_size`.
    if skipped_label_ids:
        curr_gt = ground_truth.copy()
        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
    else:
        curr_gt = ground_truth

    msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True)
    # NOTE(review): assumes elf's default IoU thresholds 0.5, 0.55, ..., 0.95, so index 5 is IoU 0.75.
    results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
    results = pd.DataFrame.from_dict([results])

    if return_segmentation:
        return results, final_segmentation
    else:
        return results


def _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values):
    """Determine the best grid-search parameters from the results in `result_dir` and save them as csv.

    Does nothing (apart from printing a message) if `best_params_path` already exists.
    """
    if os.path.exists(best_params_path):
        print("The best parameters are already saved at:", best_params_path)
        return

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))

    # Save the best parameters together with the score they achieved.
    best_kwargs["mSA"] = best_msa
    best_param_df = pd.DataFrame.from_dict([best_kwargs])
    best_param_df.to_csv(best_params_path)

    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)


def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0,
) -> str:
    """Run grid search for prompt-based multi-dimensional instance segmentation.

    The parameters and their respective value ranges for the grid search are specified via the
    `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`,
    `projection` and `box_extension`, you can pass the following:
    ```
    grid_search_values = {
        "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
        "projection": ["mask", "box", "points"],
        "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
    }
    ```
    All combinations of the parameters will be checked.
    If passed None, the function `default_grid_search_values_multi_dimensional_segmentation` is used
    to get the default grid search parameters for the instance segmentation method.

    If the results of a previous grid search already exist in `result_dir`, the segmentation is not
    re-run; only the best parameters are determined from the stored results (unless already saved).

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        result_dir: The directory where the grid search results and the best parameters are saved.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.

    Returns:
        The filepath to the csv file with the best grid search parameters.
    """
    if grid_search_values is None:
        grid_search_values = default_grid_search_values_multi_dimensional_segmentation()

    assert len(grid_search_values.keys()) == 3, "There must be three grid-search parameters. See above for details."

    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "all_grid_search_results.csv")
    best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")
    if os.path.exists(result_path):
        _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
        return best_params_path

    # Compute all combinations of grid search values and map each back to a valid kwarg input.
    parameter_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(parameter_names, values)) for values in product(*grid_search_values.values())]

    rows = []
    for gs_kwargs in tqdm(gs_combinations):
        results = segment_slices_from_ground_truth(
            volume=volume,
            ground_truth=ground_truth,
            model_type=model_type,
            checkpoint_path=checkpoint_path,
            embedding_path=embedding_path,
            interactive_seg_mode=interactive_seg_mode,
            verbose=verbose,
            return_segmentation=False,
            min_size=min_size,
            **gs_kwargs
        )

        # `results` is a single-row dataframe. Unpacking it directly (`{**results}`) would put whole
        # pandas Series objects into the cells, so the scalar scores of its only row are extracted.
        rows.append({**results.iloc[0].to_dict(), **gs_kwargs})

    res_df = pd.DataFrame(rows)
    res_df.to_csv(result_path)

    _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
    print("The best grid-search parameters have been computed and stored at:", best_params_path)
    return best_params_path
def default_grid_search_values_multi_dimensional_segmentation(
    iou_threshold_values: Optional[List[float]] = None,
    projection_method_values: Optional[List[Union[str, dict]]] = None,
    box_extension_values: Optional[List[Union[float, int]]] = None,
) -> Dict[str, List]:
    """Default grid-search parameters for multi-dimensional prompt-based instance segmentation.

    Args:
        iou_threshold_values: The values for `iou_threshold` used in the grid-search.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        projection_method_values: The values for `projection` method used in the grid-search.
            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
        box_extension_values: The values for `box_extension` used in the grid-search.
            By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by the corresponding parameter names
        of `segment_slices_from_ground_truth`.
    """
    if iou_threshold_values is None:
        iou_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)

    if projection_method_values is None:
        projection_method_values = [
            "mask", "points", "box", "points_and_mask", "single_point"
        ]

    if box_extension_values is None:
        box_extension_values = _get_range_of_search_values([0, 0.25], step=0.025)

    return {
        "iou_threshold": iou_threshold_values,
        "projection": projection_method_values,
        "box_extension": box_extension_values,
    }
Default grid-search parameters for multi-dimensional prompt-based instance segmentation.
Arguments:
- iou_threshold_values: The values for `iou_threshold` used in the grid-search. By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
- projection_method_values: The values for `projection` method used in the grid-search. By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
- box_extension_values: The values for `box_extension` used in the grid-search. By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
Returns:
The values for grid search.
@torch.no_grad()
def segment_slices_from_ground_truth(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    save_path: Optional[Union[str, os.PathLike]] = None,
    iou_threshold: float = 0.8,
    projection: Union[str, dict] = "mask",
    box_extension: Union[float, int] = 0.025,
    device: Optional[Union[str, torch.device]] = None,
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    return_segmentation: bool = False,
    min_size: int = 0,
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]]:
    """Segment all objects in a volume by prompt-based segmentation in one slice per object.

    This function first segments each object in its middle slice using interactive
    (prompt-based) segmentation with prompts derived from the ground-truth. Then it segments
    the particular object in the remaining slices in the volume.

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations. The id 0 is treated as background.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        save_path: Path to store the segmentations.
        iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation.
        projection: The projection (prompting) method to generate prompts for consecutive slices.
        box_extension: Extension factor for increasing the box size after projection.
        device: The selected device for computation.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation in the
            first (middle) slice. Either 'box' or 'points'.
        verbose: Whether to get the trace for projected segmentations.
        return_segmentation: Whether to return the segmented volume.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.

    Returns:
        A single-row dataframe with the scores 'mSA', 'SA50' and 'SA75'.
        If `return_segmentation` is True, a tuple of this dataframe and the segmented volume.

    Raises:
        ValueError: If `interactive_seg_mode` is neither 'box' nor 'points'.
    """
    assert volume.ndim == 3

    # Validate the prompting mode up-front, before loading the model and computing the embeddings.
    if interactive_seg_mode == "points":
        get_points, get_box = True, False
    elif interactive_seg_mode == "box":
        get_points, get_box = False, True
    else:
        raise ValueError(
            f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported. "
            "Please choose from 'box' / 'points'."
        )

    # Compute instance ids (without the background id 0). Filtering explicitly instead of
    # dropping the first unique value keeps the smallest object if there is no background.
    label_ids = np.unique(ground_truth)
    label_ids = label_ids[label_ids != 0]
    assert len(label_ids) > 0, "There are no objects to perform volumetric segmentation."

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)

    # Compute the image embeddings
    embeddings = util.precompute_image_embeddings(
        predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose,
    )

    # The prompt generator only depends on the prompting mode, so it is created once.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1 if get_points else 0,
        n_negative_points=1 if get_points else 0,
        dilation_strength=10,
        get_point_prompts=get_points,
        get_box_prompts=get_box,
    )

    # Create an empty volume to store incoming segmentations
    final_segmentation = np.zeros_like(ground_truth)

    skipped_label_ids = []
    for label_id in label_ids:
        # Binary label volume per instance (also referred to as object)
        this_seg = (ground_truth == label_id).astype("int")

        # Find the slices (along the first axis) that contain the current object
        # and choose the middle slice for prompt-based segmentation.
        object_slices = np.where(this_seg)[0]
        slice_range = (object_slices.min(), object_slices.max())
        slice_choice = floor(np.mean(slice_range))
        this_slice_seg = this_seg[slice_choice]
        if min_size > 0 and this_slice_seg.sum() < min_size:
            skipped_label_ids.append(label_id)
            continue

        # An object that is not contiguous along the first axis may be absent from its middle slice
        # (only reachable here for min_size == 0). Prompts cannot be derived from an empty mask,
        # so fall back to the closest slice that contains the object.
        if not this_slice_seg.any():
            candidate_slices = np.unique(object_slices)
            slice_choice = int(candidate_slices[np.argmin(np.abs(candidate_slices - slice_choice))])
            this_slice_seg = this_seg[slice_choice]

        if verbose:
            print(f"The object with id {label_id} lies in slice range: {slice_range}")

        # Derive the prompts for the chosen slice from its ground-truth mask (label 1 in the binary mask).
        _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg)
        point_prompts, point_labels, box_prompts, _ = prompt_generator(
            segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32),
            bbox_coordinates=[box_coords[1]],
        )

        # Prompt-based segmentation on the chosen slice of the current object
        output_slice = batched_inference(
            predictor=predictor,
            image=volume[slice_choice],
            batch_size=1,
            boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts,
            points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts,
            point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels,
            verbose_embeddings=verbose,
        )
        output_seg = np.zeros_like(ground_truth)
        output_seg[slice_choice][output_slice == 1] = 1

        # Segment the object in the entire volume with the specified segmented slice
        this_seg, _ = segment_mask_in_volume(
            segmentation=output_seg,
            predictor=predictor,
            image_embeddings=embeddings,
            segmented_slices=np.array(slice_choice),
            stop_lower=False,
            stop_upper=False,
            iou_threshold=iou_threshold,
            projection=projection,
            box_extension=box_extension,
            verbose=verbose,
        )

        # Store the entire segmented object (later objects overwrite earlier ones where they overlap)
        final_segmentation[this_seg == 1] = label_id

    # Save the volumetric segmentation
    if save_path is not None:
        imageio.imwrite(save_path, final_segmentation, compression="zlib")

    # Evaluate the volumetric segmentation, excluding the objects that were skipped due to `min_size`.
    if skipped_label_ids:
        curr_gt = ground_truth.copy()
        curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0
    else:
        curr_gt = ground_truth

    msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True)
    # NOTE(review): assumes elf's default IoU thresholds 0.5, 0.55, ..., 0.95, so index 5 is IoU 0.75.
    results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
    results = pd.DataFrame.from_dict([results])

    if return_segmentation:
        return results, final_segmentation
    else:
        return results
Segment all objects in a volume by prompt-based segmentation in one slice per object.
This function first segments each object in the respective specified slice using interactive (prompt-based) segmentation functionality. Then it segments the particular object in the remaining slices in the volume.
Arguments:
- volume: The input volume.
- ground_truth: The label volume with instance segmentations.
- model_type: Choice of segment anything model.
- checkpoint_path: Path to the model checkpoint.
- embedding_path: Path to cache the computed embeddings.
- save_path: Path to store the segmentations.
- iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation.
- projection: The projection (prompting) method to generate prompts for consecutive slices.
- box_extension: Extension factor for increasing the box size after projection.
- device: The selected device for computation.
- interactive_seg_mode: Method for guiding prompt-based instance segmentation.
- verbose: Whether to get the trace for projected segmentations.
- return_segmentation: Whether to return the segmented volume.
- min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.
def run_multi_dimensional_segmentation_grid_search(
    volume: np.ndarray,
    ground_truth: np.ndarray,
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    embedding_path: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    interactive_seg_mode: str = "box",
    verbose: bool = False,
    grid_search_values: Optional[Dict[str, List]] = None,
    min_size: int = 0,
) -> str:
    """Run grid search for prompt-based multi-dimensional instance segmentation.

    The parameters and their respective value ranges for the grid search are specified via the
    `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`,
    `projection` and `box_extension`, you can pass the following:
    ```
    grid_search_values = {
        "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
        "projection": ["mask", "box", "points"],
        "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
    }
    ```
    All combinations of the parameters will be checked.
    If passed None, the function `default_grid_search_values_multi_dimensional_segmentation` is used
    to get the default grid search parameters for the instance segmentation method.

    If the results of a previous grid search already exist in `result_dir`, the segmentation is not
    re-run; only the best parameters are determined from the stored results (unless already saved).

    Args:
        volume: The input volume.
        ground_truth: The label volume with instance segmentations.
        model_type: Choice of segment anything model.
        checkpoint_path: Path to the model checkpoint.
        embedding_path: Path to cache the computed embeddings.
        result_dir: The directory where the grid search results and the best parameters are saved.
        interactive_seg_mode: Method for guiding prompt-based instance segmentation.
        verbose: Whether to get the trace for projected segmentations.
        grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
        min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.

    Returns:
        The filepath to the csv file with the best grid search parameters.
    """
    if grid_search_values is None:
        grid_search_values = default_grid_search_values_multi_dimensional_segmentation()

    assert len(grid_search_values.keys()) == 3, "There must be three grid-search parameters. See above for details."

    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "all_grid_search_results.csv")
    best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")
    if os.path.exists(result_path):
        _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
        return best_params_path

    # Compute all combinations of grid search values and map each back to a valid kwarg input.
    parameter_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(parameter_names, values)) for values in product(*grid_search_values.values())]

    rows = []
    for gs_kwargs in tqdm(gs_combinations):
        results = segment_slices_from_ground_truth(
            volume=volume,
            ground_truth=ground_truth,
            model_type=model_type,
            checkpoint_path=checkpoint_path,
            embedding_path=embedding_path,
            interactive_seg_mode=interactive_seg_mode,
            verbose=verbose,
            return_segmentation=False,
            min_size=min_size,
            **gs_kwargs
        )

        # `results` is a single-row dataframe. Unpacking it directly (`{**results}`) would put whole
        # pandas Series objects into the cells, so the scalar scores of its only row are extracted.
        rows.append({**results.iloc[0].to_dict(), **gs_kwargs})

    res_df = pd.DataFrame(rows)
    res_df.to_csv(result_path)

    _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
    print("The best grid-search parameters have been computed and stored at:", best_params_path)
    return best_params_path
Run grid search for prompt-based multi-dimensional instance segmentation.
The parameters and their respective value ranges for the grid search are specified via the `grid_search_values` argument. For example, to run a grid search over the parameters `iou_threshold`, `projection` and `box_extension`, you can pass the following:
grid_search_values = {
"iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
"projection": ["mask", "bounding_box", "points"],
"box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0,5],
}
All combinations of the parameters will be checked.
If passed None, the function `default_grid_search_values_multi_dimensional_segmentation` is used to get the default grid search parameters for the instance segmentation method.
Arguments:
- volume: The input volume.
- ground_truth: The label volume with instance segmentations.
- model_type: Choice of segment anything model.
- checkpoint_path: Path to the model checkpoint.
- embedding_path: Path to cache the computed embeddings.
- result_dir: The directory where the grid search results are saved.
- interactive_seg_mode: Method for guiding prompt-based instance segmentation.
- verbose: Whether to get the trace for projected segmentations.
- grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
- min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice.