micro_sam.evaluation.instance_segmentation
Inference and evaluation for the automatic instance segmentation functionality.
"""Inference and evaluation for the automatic instance segmentation functionality.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import imageio.v3 as imageio

from elf.io import open_file
from elf.evaluation import mean_segmentation_accuracy

from .. import util
from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation


def _get_range_of_search_values(input_vals, step):
    """Expand a `[start, stop]` pair into an inclusive list of grid-search values.

    Args:
        input_vals: Either a two-element list `[start, stop]`, which is expanded into
            `start, start + step, ...` up to and including `stop` (if `stop` lies on the grid),
            or any other value, which is returned as a one-element list.
        step: The step size between consecutive values.

    Returns:
        The search values, rounded to three decimals.
    """
    if not isinstance(input_vals, list):
        return [input_vals]

    start, stop = input_vals[0], input_vals[1]
    # Count the steps explicitly instead of using `np.arange(start, stop + step, step)`:
    # with a floating-point step, arange can emit one value past `stop`
    # (e.g. [1.0, 2.0] with step 0.2 would also yield 2.2). The small tolerance keeps
    # `stop` itself when it lies on the grid despite rounding error.
    n_steps = int(np.floor((stop - start) / step + 1e-6))
    search_range = start + step * np.arange(n_steps + 1)
    return [round(float(value), 3) for value in search_range]


def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }


def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if distance_smoothing_values is None:
        distance_smoothing_values = _get_range_of_search_values([1.0, 2.0], step=0.2)
    if min_size_values is None:
        min_size_values = [50, 100, 200]
    return {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }


def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Evaluate every grid-search combination on one image the segmenter was initialized with.

    Args:
        segmenter: The segmenter, already initialized for the current image.
        gs_combinations: One dict of `generate` keyword arguments per parameter combination.
        gt: The ground-truth segmentation for the image.
        image_name: Name stored in the `image_name` column of the results.
        fixed_generate_kwargs: Keyword arguments passed to `generate` for every combination.
        result_path: Where to write the results as csv. Nothing is written if None.
        verbose: Whether to show a progress bar over the combinations.

    Returns:
        One row per combination with the columns `image_name`, `mSA`, `SA50`, `SA75`
        followed by the grid-search parameters.
    """
    rows = []
    for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
        generate_kwargs = gs_kwargs | fixed_generate_kwargs
        masks = segmenter.generate(**generate_kwargs)

        min_object_size = generate_kwargs.get("min_mask_region_area", 0)
        if len(masks) == 0:
            # No objects were found: evaluate an all-background segmentation.
            instance_labels = np.zeros(gt.shape, dtype="uint32")
        else:
            instance_labels = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
        m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True)  # type: ignore

        # Entries 0 and 5 of the per-threshold accuracies are reported as SA50 and SA75.
        rows.append({"image_name": image_name, "mSA": m_sas, "SA50": sas[0], "SA75": sas[5], **gs_kwargs})

    img_gs_df = pd.DataFrame(rows)
    if result_path is not None:
        img_gs_df.to_csv(result_path, index=False)

    return img_gs_df


def _load_image(path, key, roi):
    """Load an image (optionally restricted to `roi`).

    If `key` is None the file is read with imageio, otherwise `key` is read from the
    container opened via `elf.io.open_file` (e.g. HDF5 / zarr / n5).
    """
    if key is None:
        im = imageio.imread(path)
        if roi is not None:
            im = im[roi]
        return im
    with open_file(path, "r") as f:
        im = f[key][:] if roi is None else f[key][roi]
    return im


def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Results are cached as one csv file per image in `result_dir`; images whose csv already
    exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to, one per image.

    Raises:
        ValueError: If a parameter appears in both `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs

    duplicate_params = [gs_param for gs_param in grid_search_values if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Map every combination of grid search values to a kwargs dict for `generate`.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            # `image_name` is already the stem. Stripping a further suffix would make e.g.
            # 'a.1.tif' and 'a.2.tif' share the cached embeddings 'a.zarr'.
            embedding_path = os.path.join(embedding_dir, f"{image_name}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )


def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Predictions are written to `prediction_dir` under the input file name; images whose
    prediction already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. Created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # Same lookup as in `run_instance_segmentation_grid_search`.
    predictor = getattr(segmenter, "_predictor", None)
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Writing the predictions below would fail if the output folder is missing.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)


def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Reads all csv files in `result_dir`, averages `criterion` over the images for each
    parameter combination and selects the combination with the highest mean.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting, mapping parameter names to plain Python values.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files.
    """
    # Load all the grid search results.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"No grid-search result files (*.csv) were found in {result_dir}.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Average the criterion over all images for each parameter combination.
    grouped_result = gs_result.groupby(grid_search_parameters)[criterion].mean().reset_index()

    best_idx = grouped_result[criterion].idxmax()
    best_score = float(grouped_result.at[best_idx, criterion])

    # Read each parameter from its own column: taking the whole row would upcast
    # integer parameters (e.g. `min_size`) to float.
    best_kwargs = {}
    for param in grid_search_parameters:
        value = grouped_result.at[best_idx, param]
        best_kwargs[param] = value.item() if hasattr(value, "item") else value

    return best_kwargs, best_score


def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best grid-search parameters and their score as a csv file.

    The file is named `grid_search_params_amg.csv` if `best_kwargs` contains both
    `pred_iou_thresh` and `stability_score_thresh`, and
    `grid_search_params_instance_segmentation_with_decoder.csv` otherwise.

    Args:
        best_kwargs: The best parameter setting.
        best_msa: The score of the best parameter setting, stored in the `best_msa` column.
        grid_search_result_dir: If given, the file is written to its `results` sub-folder
            (created if needed); otherwise to the current working directory.
    """
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # Both AMG parameters must be present (membership is tested separately for each key).
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    path_name = "grid_search_params_amg.csv" if is_amg else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            This dict is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
    """
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, val_image_paths, val_gt_paths,
        result_dir=result_dir, embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Copy so that the caller's `fixed_generate_kwargs` is not polluted with the best grid-search values.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
    )
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }
Default grid-search parameter for AMG-based instance segmentation.
Return grid search values for the two most important parameters:
- `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
- `stability_score_thresh`, the threshold for keeping objects according to their stability.
Arguments:
- iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
- stability_score_values: The values for `stability_score_thresh` used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:
The values for grid search.
def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if distance_smoothing_values is None:
        distance_smoothing_values = _get_range_of_search_values([1.0, 2.0], step=0.2)
    if min_size_values is None:
        min_size_values = [50, 100, 200]
    return {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }
Default grid-search parameter for decoder-based instance segmentation.
Arguments:
- center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
- min_size_values: The values for `min_size` used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:
The values for grid search.
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Results are cached as one csv file per image in `result_dir`; images whose csv already
    exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to, one per image.

    Raises:
        ValueError: If a parameter appears in both `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs

    duplicate_params = [gs_param for gs_param in grid_search_values if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Map every combination of grid search values to a kwargs dict for `generate`.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            # `image_name` is already the stem. Stripping a further suffix would make e.g.
            # 'a.1.tif' and 'a.2.tif' share the cached embeddings 'a.zarr'.
            embedding_path = os.path.join(embedding_dir, f"{image_name}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
Run grid search for automatic mask generation.
The parameters and their respective value ranges for the grid search are specified via the 'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh' and 'stability_score_thresh', you can pass the following:
grid_search_values = {
"pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
"stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
}
All combinations of the parameters will be checked.
You can use the functions `default_grid_search_values_instance_segmentation_with_decoder` or `default_grid_search_values_amg` to get the default grid search parameters for the two respective instance segmentation methods.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- image_paths: The input images for the grid search.
- gt_paths: The ground-truth segmentation for the grid search.
- result_dir: Folder to cache the evaluation results per image.
- embedding_dir: Folder to cache the image embeddings.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
- image_key: Key for loading the image data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- gt_key: Key for loading the ground-truth data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- rois: Regions of interest to restrict the evaluation to, one per image.
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Predictions are written to `prediction_dir` under the input file name; images whose
    prediction already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. Created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # Same lookup as in `run_instance_segmentation_grid_search`.
    predictor = getattr(segmenter, "_predictor", None)
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Writing the predictions below would fail if the output folder is missing.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
Run inference for automatic mask generation.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- image_paths: The input images.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Reads all csv files in `result_dir`, averages `criterion` over the images for each
    parameter combination and selects the combination with the highest mean.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting, mapping parameter names to plain Python values.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files.
    """
    # Load all the grid search results.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"No grid-search result files (*.csv) were found in {result_dir}.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Average the criterion over all images for each parameter combination.
    grouped_result = gs_result.groupby(grid_search_parameters)[criterion].mean().reset_index()

    best_idx = grouped_result[criterion].idxmax()
    best_score = float(grouped_result.at[best_idx, criterion])

    # Read each parameter from its own column: taking the whole row would upcast
    # integer parameters (e.g. `min_size`) to float.
    best_kwargs = {}
    for param in grid_search_parameters:
        value = grouped_result.at[best_idx, param]
        best_kwargs[param] = value.item() if hasattr(value, "item") else value

    return best_kwargs, best_score
Evaluate gridsearch results.
Arguments:
- result_dir: The folder with the gridsearch results.
- grid_search_parameters: The names for the gridsearch parameters.
- criterion: The metric to use for determining the best parameters.
Returns:
The best parameter setting and the evaluation score for the best setting.
def save_grid_search_best_params(
    best_kwargs: Dict[str, Any],
    best_msa: float,
    grid_search_result_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Save the best gridsearch parameters together with their score to a csv file.

    The file is called `grid_search_params_amg.csv` if `best_kwargs` contains an AMG parameter
    (`pred_iou_thresh` or `stability_score_thresh`), and
    `grid_search_params_instance_segmentation_with_decoder.csv` otherwise.

    Args:
        best_kwargs: The best parameter setting.
        best_msa: The evaluation score for the best setting.
        grid_search_result_dir: The file is saved in the `results` subfolder of this folder,
            which is created if necessary. If None, the file is saved in the current working directory.
    """
    # One-row table: the score in the first column, followed by the parameters.
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # Any AMG-specific parameter identifies an AMG gridsearch.
    # (Previously `"a" and "b" in d` only tested for "b" due to operator precedence.)
    amg_params = ("pred_iou_thresh", "stability_score_thresh")
    if any(name in best_kwargs for name in amg_params):
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is None:
        res_path = path_name
    else:
        result_folder = os.path.join(grid_search_result_dir, "results")
        os.makedirs(result_folder, exist_ok=True)
        res_path = os.path.join(result_folder, path_name)

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    The grid search is run on the validation images, the best parameters (according to the
    default criterion of `evaluate_instance_segmentation_grid_search`) are saved in
    `experiment_folder`, and inference on the test images is run with these parameters
    combined with `fixed_generate_kwargs`.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            This dict is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
    """
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, val_image_paths, val_gt_paths,
        result_dir=result_dir, embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Work on a copy so that the update below does not mutate the caller's dict.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
    )
Run grid search and inference for automatic mask generation.
Please refer to the documentation of `run_instance_segmentation_grid_search`
for details on how to specify the grid search parameters.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- val_image_paths: The input images for the grid search.
- val_gt_paths: The ground-truth segmentation for the grid search.
- test_image_paths: The input images for inference.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- experiment_folder: Folder for caching best grid search parameters in 'results'.
- result_dir: Folder to cache the evaluation results per image.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.