micro_sam.evaluation.instance_segmentation
Inference and evaluation for the automatic instance segmentation functionality.
"""Inference and evaluation for the automatic instance segmentation functionality.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import imageio.v3 as imageio

from elf.io import open_file
from elf.evaluation import mean_segmentation_accuracy

from .. import util
from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation


def _get_range_of_search_values(input_vals, step):
    """Expand a `[start, stop]` pair into a list of grid-search values.

    Args:
        input_vals: Either a list whose first two entries are the start and stop of the range,
            or a single value.
        step: The step size between consecutive values of the range.

    Returns:
        The values of the range rounded to three decimals if `input_vals` is a list,
        otherwise a one-element list holding `input_vals`.
    """
    if isinstance(input_vals, list):
        # `np.arange` excludes its end point, so the stop value is extended by one step.
        search_range = np.arange(input_vals[0], input_vals[1] + step, step)
        search_range = [round(e, 3) for e in search_range]
    else:
        search_range = [input_vals]
    return search_range


def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }


def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search.
    """
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values(
            [0.3, 0.7], step=0.1
        )
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values(
            [0.3, 0.7], step=0.1
        )
    if distance_smoothing_values is None:
        distance_smoothing_values = _get_range_of_search_values(
            [1.0, 2.0], step=0.2
        )
    if min_size_values is None:
        min_size_values = [50, 100, 200]

    return {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }


def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Evaluate all grid-search combinations for a single, already initialized image.

    Args:
        segmenter: The segmenter, on which `initialize` has been called for the current image.
        gs_combinations: The parameter combinations; each entry is passed to `segmenter.generate`.
        gt: The ground-truth segmentation for the current image.
        image_name: The name stored in the 'image_name' column of the result table.
        fixed_generate_kwargs: Keyword arguments passed to `generate` for every combination.
        result_path: The path passed to `DataFrame.to_csv` for saving the result table.
        verbose: Whether to show a progress bar over the combinations.

    Returns:
        The result table with one row per parameter combination.
    """
    net_list = []
    for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
        generate_kwargs = gs_kwargs | fixed_generate_kwargs
        masks = segmenter.generate(**generate_kwargs)

        min_object_size = generate_kwargs.get("min_mask_region_area", 0)
        if len(masks) == 0:
            # Nothing was segmented, so the prediction is an empty label image.
            instance_labels = np.zeros(gt.shape, dtype="uint32")
        else:
            instance_labels = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
        m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True)  # type: ignore

        # The first and sixth entry of the returned accuracies are stored as SA50 and SA75.
        result_dict = {"image_name": image_name, "mSA": m_sas, "SA50": sas[0], "SA75": sas[5]}
        result_dict.update(gs_kwargs)
        tmp_df = pd.DataFrame([result_dict])
        net_list.append(tmp_df)

    img_gs_df = pd.concat(net_list)
    img_gs_df.to_csv(result_path, index=False)

    return img_gs_df


def _load_image(path, key, roi):
    """Load image data, optionally restricted to a region of interest.

    Args:
        path: The file path of the image.
        key: The key of the dataset to read via `open_file`.
            If None, the file is read with `imageio.imread` instead.
        roi: The index expression applied to the data, or None to load everything.

    Returns:
        The loaded image data.
    """
    if key is None:
        im = imageio.imread(path)
        if roi is not None:
            im = im[roi]
        return im

    with open_file(path, "r") as f:
        im = f[key][:] if roi is None else f[key][roi]

    return im


def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to.
            If given, the i-th entry is applied to the i-th image and ground-truth.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

    Raises:
        ValueError: If a parameter is given in both 'grid_search_values' and 'fixed_generate_kwargs'.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
    # The tiling parameters do not change between images, so the default is resolved once.
    if tiling_window_params is None:
        tiling_window_params = {}

    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values.
    gs_combinations = product(*grid_search_values.values())
    # Map each combination back to a valid kwarg input.
    gs_combinations = [
        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
    ]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
        gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )


def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """

    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # The tiling parameters do not change between images, so the default is resolved once.
    if tiling_window_params is None:
        tiling_window_params = {}

    predictor = segmenter._predictor
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Make sure the output folder exists before predictions are written to it.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)


def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv file with grid search results.
    """
    grid_search_parameters = list(grid_search_parameters)

    # Load all the grid search results.
    gs_files = glob(os.path.join(result_dir, "*.csv"))
    if not gs_files:
        raise ValueError(f"Could not find any grid search results (*.csv) in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])

    # Retrieve only the relevant columns, so that nothing but the criterion is averaged.
    gs_result = gs_result[grid_search_parameters + [criterion]]

    # Compute the mean of the criterion for each parameter combination.
    grouped_result = gs_result.groupby(grid_search_parameters).mean().reset_index()

    # Find the best score and corresponding parameters.
    best_score, best_idx = grouped_result[criterion].max(), grouped_result[criterion].idxmax()
    best_params = grouped_result.iloc[best_idx]
    assert np.isclose(best_params[criterion], best_score)
    # Look the parameters up by name rather than relying on the column order.
    best_kwargs = {k: best_params[k] for k in grid_search_parameters}

    return best_kwargs, best_score


def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best grid search parameters and their score to a csv file.

    Args:
        best_kwargs: The best parameter setting found by the grid search.
        best_msa: The evaluation score of the best setting, stored in the column 'best_msa'.
        grid_search_result_dir: The folder in which a 'results' subfolder with the csv file is created.
            If None, the csv file is written to the current working directory.
    """
    # saving the best parameters estimated from grid-search in the `results` folder
    param_df = pd.DataFrame.from_dict([best_kwargs])
    res_df = pd.DataFrame.from_dict([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # Both AMG parameters must be present; `"a" and "b" in d` would only test the second key.
    if all(param in best_kwargs for param in ("pred_iou_thresh", "stability_score_thresh")):
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)


def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Work on a copy, so that the dictionary passed by the caller is not modified.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Assemble the grid-search values for AMG-based instance segmentation.

    The grid search covers the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            If None, the range from 0.6 to 0.9 with a stepsize of 0.025 is requested.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            If None, the range from 0.6 to 0.95 with a stepsize of 0.025 is requested.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    # Values given by the caller are passed through unchanged; only missing ones get a default range.
    pred_iou_thresh = (
        _get_range_of_search_values([0.6, 0.9], step=0.025)
        if iou_thresh_values is None else iou_thresh_values
    )
    stability_score_thresh = (
        _get_range_of_search_values([0.6, 0.95], step=0.025)
        if stability_score_values is None else stability_score_values
    )

    grid_search_values = {}
    grid_search_values["pred_iou_thresh"] = pred_iou_thresh
    grid_search_values["stability_score_thresh"] = stability_score_thresh
    return grid_search_values
Default grid-search parameter for AMG-based instance segmentation.
Return grid search values for the two most important parameters:
- `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
- `stability_score_thresh`, the threshold for keeping objects according to their stability.
Arguments:
- iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
- stability_score_values: The values for `stability_score_thresh` used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:
The values for grid search.
def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Assemble the grid-search values for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            If None, the range from 0.3 to 0.7 with a stepsize of 0.1 is requested.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            If None, the range from 0.3 to 0.7 with a stepsize of 0.1 is requested.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            If None, the range from 1.0 to 2.0 with a stepsize of 0.2 is requested.
        min_size_values: The values for `min_size` used in the gridsearch.
            If None, the values 50, 100 and 200 are used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    # Values given by the caller are passed through unchanged; only missing ones get a default.
    center_distance_threshold = (
        _get_range_of_search_values([0.3, 0.7], step=0.1)
        if center_distance_threshold_values is None else center_distance_threshold_values
    )
    boundary_distance_threshold = (
        _get_range_of_search_values([0.3, 0.7], step=0.1)
        if boundary_distance_threshold_values is None else boundary_distance_threshold_values
    )
    distance_smoothing = (
        _get_range_of_search_values([1.0, 2.0], step=0.2)
        if distance_smoothing_values is None else distance_smoothing_values
    )
    min_size = [50, 100, 200] if min_size_values is None else min_size_values

    grid_search_values = {}
    grid_search_values["center_distance_threshold"] = center_distance_threshold
    grid_search_values["boundary_distance_threshold"] = boundary_distance_threshold
    grid_search_values["distance_smoothing"] = distance_smoothing
    grid_search_values["min_size"] = min_size
    return grid_search_values
Default grid-search parameter for decoder-based instance segmentation.
Arguments:
- center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
- min_size_values: The values for `min_size` used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:
The values for grid search.
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to.
            If given, the i-th entry is applied to the i-th image and ground-truth.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

    Raises:
        ValueError: If a parameter is given in both 'grid_search_values' and 'fixed_generate_kwargs'.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
    # The tiling parameters do not change between images, so the default is resolved once.
    if tiling_window_params is None:
        tiling_window_params = {}

    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values.
    gs_combinations = product(*grid_search_values.values())
    # Map each combination back to a valid kwarg input.
    gs_combinations = [
        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
    ]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
        gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
Run grid search for automatic mask generation.
The parameters and their respective value ranges for the grid search are specified via the 'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh' and 'stability_score_thresh', you can pass the following:
grid_search_values = {
"pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
"stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
}
All combinations of the parameters will be checked.
You can use the functions default_grid_search_values_instance_segmentation_with_decoder
or default_grid_search_values_amg
to get the default grid search parameters for the two
respective instance segmentation methods.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- image_paths: The input images for the grid search.
- gt_paths: The ground-truth segmentation for the grid search.
- result_dir: Folder to cache the evaluation results per image.
- embedding_dir: Folder to cache the image embeddings.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
- image_key: Key for loading the image data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- gt_key: Key for loading the ground-truth data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- rois: Regions of interest to restrict the evaluation to.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """

    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # The tiling parameters do not change between images, so the default is resolved once.
    if tiling_window_params is None:
        tiling_window_params = {}

    predictor = segmenter._predictor
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Make sure the output folder exists before predictions are written to it.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
Run inference for automatic mask generation.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- image_paths: The input images.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
- tiling_window_params: The parameters to decide whether to use tiling window operation for automatic segmentation.
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Loads all csv files in `result_dir`, averages `criterion` over all rows that share
    the same combination of grid search parameters and returns the combination
    with the highest average.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv file.
    """
    # Load all the grid search results. The files are sorted because the order returned
    # by glob is arbitrary, and we want the evaluation to be reproducible.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"No grid search results (*.csv) found in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Retrieve only the relevant columns, so that nothing but the criterion is averaged.
    parameter_names = list(grid_search_parameters)
    gs_result = gs_result[parameter_names + [criterion]]

    # Compute the mean over the grouped columns.
    grouped_result = gs_result.groupby(parameter_names).mean().reset_index()

    # Find the best score and corresponding parameters.
    # 'idxmax' returns an index label, hence the row is selected with 'loc'.
    best_score = grouped_result[criterion].max()
    best_params = grouped_result.loc[grouped_result[criterion].idxmax()]
    assert np.isclose(best_params[criterion], best_score)

    # Look the values up by name instead of relying on the column order.
    best_kwargs = {name: best_params[name] for name in parameter_names}

    return best_kwargs, best_score
Evaluate gridsearch results.
Arguments:
- result_dir: The folder with the gridsearch results.
- grid_search_parameters: The names for the gridsearch parameters.
- criterion: The metric to use for determining the best parameters.
Returns:
A tuple of two values: the best parameter setting, and the evaluation score for the best setting.
def save_grid_search_best_params(
    best_kwargs: Dict[str, Any],
    best_msa: float,
    grid_search_result_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Save the best parameters estimated from grid-search to a csv file.

    The table has a single row with the column `best_msa` followed by one column per parameter.
    The file is called 'grid_search_params_amg.csv' if the parameters contain both
    `pred_iou_thresh` and `stability_score_thresh`, and
    'grid_search_params_instance_segmentation_with_decoder.csv' otherwise.

    Args:
        best_kwargs: The best parameter setting found by the grid search.
        best_msa: The evaluation score for the best setting.
        grid_search_result_dir: The folder for saving the parameters. If given, the file is saved
            in the sub-folder 'results', which is created if needed.
            Otherwise it is saved under the file name alone, as a relative path.
    """
    param_df = pd.DataFrame([best_kwargs])
    res_df = pd.DataFrame([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # Both keys have to be checked explicitly: the expression `"a" and "b" in d` only tests "b",
    # because the non-empty string "a" is always truthy.
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    if is_amg:
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    The grid search is run on the validation images, the best parameters are determined
    and saved, and inference is then run on the test images with these parameters.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            The passed dictionary is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Combine the fixed and the best parameters in a new dictionary,
    # so that the dictionary passed by the caller is not changed as a side effect.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )
Run grid search and inference for automatic mask generation.
Please refer to the documentation of `run_instance_segmentation_grid_search` for details on how to specify the grid search parameters.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- val_image_paths: The input images for the grid search.
- val_gt_paths: The ground-truth segmentation for the grid search.
- test_image_paths: The input images for inference.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- experiment_folder: Folder for caching best grid search parameters in 'results'.
- result_dir: Folder to cache the evaluation results per image.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
- tiling_window_params: The parameters to decide whether to use tiling window operation for automatic segmentation.