micro_sam.evaluation.instance_segmentation
Inference and evaluation for the automatic instance segmentation functionality.
"""Inference and evaluation for the automatic instance segmentation functionality.
"""

import ast
import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import imageio.v3 as imageio

from elf.io import open_file
from elf.evaluation import mean_segmentation_accuracy, matching

from .. import util
from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder


def _get_range_of_search_values(input_vals, step):
    """Expand a `[start, stop]` pair into the list of values used for the grid search.

    Args:
        input_vals: Either a list whose first two entries are the start and the (inclusive) stop of the range,
            or a single value, which is then used as the only search value.
        step: The spacing between consecutive values of the range.

    Returns:
        The search values as a list. Values derived from a range are rounded to three decimals.
    """
    if not isinstance(input_vals, list):
        return [input_vals]

    start, stop = input_vals[0], input_vals[1]
    # The number of values is computed explicitly instead of calling `np.arange(start, stop + step, step)`:
    # for a floating point step `np.arange` may return a value beyond `stop` due to rounding errors.
    # The small offset protects against a quotient that ends up marginally below a whole number.
    n_values = max(int(np.floor((stop - start) / step + 1e-9)) + 1, 0)
    return [round(e, 3) for e in start + step * np.arange(n_values)]


def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }


def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search.
    """
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values(
            [0.3, 0.7], step=0.1
        )
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values(
            [0.3, 0.7], step=0.1
        )
    if distance_smoothing_values is None:
        distance_smoothing_values = _get_range_of_search_values(
            [1.0, 2.0], step=0.2
        )
    if min_size_values is None:
        min_size_values = [50, 100, 200]

    return {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }


def default_grid_search_values_apg(
    min_distance_values: Optional[List[float]] = None,
    threshold_abs_values: Optional[List[float]] = None,
    multimasking_values: Optional[List[bool]] = None,
    prompt_selection_values: Optional[List[Union[str, List[str]]]] = None,
    min_size_values: Optional[List[float]] = None,
    nms_threshold_values: Optional[List[float]] = None,
    intersection_over_min_values: Optional[List[bool]] = None,
    mask_threshold_values: Optional[List[Union[float, str]]] = None,
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
) -> Dict[str, List[Any]]:
    """Default grid-search parameter for APG-based instance segmentation.

    Only the parameters `center_distance_threshold`, `boundary_distance_threshold`, `min_size`,
    `nms_threshold` and `intersection_over_min` are part of the returned grid search values.
    The remaining arguments are accepted, but currently ignored.

    Args:
        min_distance_values: Currently ignored; `min_distance` is not part of the grid search.
        threshold_abs_values: Currently ignored; `threshold_abs` is not part of the grid search.
        multimasking_values: Currently ignored; `multimasking` is not part of the grid search.
        prompt_selection_values: Currently ignored; `prompt_selection` is not part of the grid search.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.
        nms_threshold_values: The values for `nms_threshold` used in the gridsearch.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        intersection_over_min_values: The values for `intersection_over_min` used in the gridsearch.
            By default the values True and False are used.
        mask_threshold_values: Currently ignored; `mask_threshold` is not part of the grid search.
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.

    Returns:
        The values for grid search.
    """
    # NOTE: The two combinations below are for distances. Since we use connected components, we don't run them!
    # if min_distance_values is None:
    #     min_distance_values = _get_range_of_search_values([1, 5], step=1)
    # if threshold_abs_values is None:
    #     threshold_abs_values = _get_range_of_search_values([0.1, 0.5], step=0.1)

    # if multimasking_values is None:
    #     multimasking_values = [True, False]
    # if prompt_selection_values is None:
    #     prompt_selection_values = [
    #         "center_distances",
    #         "boundary_distances",
    #         "connected_components",
    #         ["center_distances", "connected_components"],
    #         ["center_distances", "boundary_distances"],
    #         ["boundary_distances", "connected_components"],
    #         ["center_distances", "boundary_distances", "connected_components"]
    #     ]

    # NOTE: The two parameters below are for connected components.
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)

    if min_size_values is None:
        min_size_values = [50, 100, 200]
    if nms_threshold_values is None:
        nms_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)
    if intersection_over_min_values is None:
        intersection_over_min_values = [True, False]
    # if mask_threshold_values is None:
    #     mask_threshold_values = [None, "auto"]  # 'None' derives the default from the model.

    return {
        # "min_distance": min_distance_values,
        # "threshold_abs": threshold_abs_values,
        # "multimasking": multimasking_values,
        # "prompt_selection": prompt_selection_values,
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "min_size": min_size_values,
        "nms_threshold": nms_threshold_values,
        "intersection_over_min": intersection_over_min_values,
        # "mask_threshold": mask_threshold_values,
    }


def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Evaluate all grid search parameter combinations for a single image.

    Args:
        segmenter: The segmenter; its `generate` method is called once per parameter combination.
        gs_combinations: The parameter combinations, each given as keyword arguments for `generate`.
        gt: The ground-truth segmentation the predictions are compared to.
        image_name: The name of the image, stored in the 'image_name' column of the result.
        fixed_generate_kwargs: Keyword arguments passed to `generate` for every combination.
        result_path: Where to save the results as csv.
        verbose: Whether to show a progress bar over the parameter combinations.

    Returns:
        The evaluation results with one row per parameter combination.
    """
    net_list = []
    for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
        generate_kwargs = gs_kwargs | fixed_generate_kwargs
        instance_labels = segmenter.generate(**generate_kwargs)
        m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True)
        stats = matching(instance_labels, gt)

        result_dict = {
            "image_name": image_name,
            "mSA": m_sas,
            # NOTE(review): presumably the accuracies are reported for IoU thresholds starting at 0.5
            # with a spacing of 0.05, so that index 5 corresponds to 0.75 - confirm against `elf`.
            "SA50": sas[0],
            "SA75": sas[5],
            "Precision": stats["precision"],
            "Recall": stats["recall"],
            "F1": stats["f1"],
        }
        result_dict.update(gs_kwargs)
        net_list.append(pd.DataFrame([result_dict]))

    img_gs_df = pd.concat(net_list)
    img_gs_df.to_csv(result_path, index=False)

    return img_gs_df


def _load_image(path, key, roi):
    """Load an image, optionally restricted to a region of interest.

    Args:
        path: The file path of the image.
        key: The internal key for the data. If None, the file is read with imageio,
            otherwise it is opened with `open_file` and the dataset at `key` is read.
        roi: The region of interest used to index the data. If None, the full data is loaded.

    Returns:
        The image data.
    """
    if key is None:
        im = imageio.imread(path)
        return im if roi is None else im[roi]

    with open_file(path, "r") as f:
        im = f[key][:] if roi is None else f[key][roi]
    return im


def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`,
    `default_grid_search_values_amg` or `default_grid_search_values_apg` to get the default
    grid search parameters for the respective instance segmentation methods.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
            Images that already have a result file in this folder are skipped.
        embedding_dir: Folder to cache the image embeddings.
            If None, the segmenter is initialized from the image without precomputed embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to. The entry at position i
            is applied to the i-th image and ground-truth.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            They are passed as keyword arguments to the embedding computation and to `segmenter.initialize`.

    Raises:
        ValueError: If a parameter is given both in `grid_search_values` and in `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params

    duplicate_params = [gs_param for gs_param in grid_search_values if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values and map each one back to a valid kwarg input.
    gs_combinations = [
        dict(zip(grid_search_values, vals)) for vals in product(*grid_search_values.values())
    ]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            segmenter.initialize(image, **tiling_window_params)
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
            image_embeddings = util.precompute_image_embeddings(
                predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
            )
            segmenter.initialize(image, image_embeddings, **tiling_window_params)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )


def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
            If None, the embeddings are computed without a cache path.
        prediction_dir: Folder to save the predictions. Each prediction is saved under the file name
            of its input image; images with an existing prediction are skipped.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params
    predictor = segmenter._predictor

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)
        instances = segmenter.generate(**generate_kwargs)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)


def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files with grid search results.
    """
    # Load all the grid search results.
    gs_files = glob(os.path.join(result_dir, "*.csv"))
    if not gs_files:
        raise ValueError(f"Could not find any grid search results (*.csv files) in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])

    # Retrieve only the relevant columns.
    gs_result = gs_result[grid_search_parameters + [criterion]]

    # Group by the gridsearch columns and compute the mean of the criterion over all images.
    grouped_result = gs_result.groupby(grid_search_parameters).mean().reset_index()

    # Find the best score and corresponding parameters.
    # After 'reset_index' the label returned by 'idxmax' is also the row position.
    best_score, best_idx = grouped_result[criterion].max(), grouped_result[criterion].idxmax()
    best_params = grouped_result.iloc[best_idx]
    assert np.isclose(best_params[criterion], best_score)
    best_kwargs = {k: best_params[k] for k in grid_search_parameters}

    return best_kwargs, best_score


def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best parameters found by the grid search to a csv file.

    Args:
        best_kwargs: The best parameter setting, given as a mapping from parameter name to value.
        best_msa: The evaluation score of the best setting, saved in the column 'best_msa'.
        grid_search_result_dir: The folder for saving the result. The file is written to the sub-folder
            'results' of this folder. If None, it is written to the current working directory.
    """
    param_df = pd.DataFrame.from_dict([best_kwargs])
    res_df = pd.DataFrame.from_dict([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # Both AMG parameters have to be present to save the result as AMG parameters.
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    path_name = "grid_search_params_amg.csv" if is_amg \
        else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)


def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            The passed dictionary is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Copy the fixed arguments, so that adding the best parameters does not modify the dictionary of the caller.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    # NOTE: Make sure the 'prompt_selection' values for APG are as expected
    if "prompt_selection" in generate_kwargs:
        generate_kwargs["prompt_selection"] = _maybe_list_value(generate_kwargs["prompt_selection"])

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )


def _maybe_list_value(val):
    """Convert the string representation of a list back into a list.

    Args:
        val: The value to convert.

    Returns:
        The parsed list if `val` is a string that starts with '[', ends with ']' and parses to a list.
        Otherwise `val` is returned unchanged.
    """
    # In case it's not a string, well we ignore it.
    if not isinstance(val, str):
        return val

    s = val.strip()
    # Let's try to parse through values that appear to be an obvious list.
    if s.startswith("[") and s.endswith("]"):
        try:
            parsed = ast.literal_eval(s)
        except (ValueError, SyntaxError):
            # The value only looks like a list but is not a valid literal, so we keep it as it is.
            return val
        if isinstance(parsed, list):
            return parsed

    return val
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Build the default grid-search values for AMG-based instance segmentation.

    The grid search covers the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search.
    """
    # Values passed by the caller take precedence, the default ranges are only a fallback.
    pred_iou_thresh = (
        _get_range_of_search_values([0.6, 0.9], step=0.025) if iou_thresh_values is None else iou_thresh_values
    )
    stability_score_thresh = (
        _get_range_of_search_values([0.6, 0.95], step=0.025)
        if stability_score_values is None else stability_score_values
    )
    return {"pred_iou_thresh": pred_iou_thresh, "stability_score_thresh": stability_score_thresh}
Default grid-search parameter for AMG-based instance segmentation.
Return grid search values for the two most important parameters:
- `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
- `stability_score_thresh`, the threshold for keeping objects according to their stability.
Arguments:
- iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
- stability_score_values: The values for `stability_score_thresh` used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:
The values for grid search.
def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Build the default grid-search values for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search.
    """
    # Values passed by the caller take precedence, the defaults are only a fallback.
    center_distances = (
        _get_range_of_search_values([0.3, 0.7], step=0.1)
        if center_distance_threshold_values is None else center_distance_threshold_values
    )
    boundary_distances = (
        _get_range_of_search_values([0.3, 0.7], step=0.1)
        if boundary_distance_threshold_values is None else boundary_distance_threshold_values
    )
    smoothing = (
        _get_range_of_search_values([1.0, 2.0], step=0.2)
        if distance_smoothing_values is None else distance_smoothing_values
    )
    min_sizes = [50, 100, 200] if min_size_values is None else min_size_values

    return {
        "center_distance_threshold": center_distances,
        "boundary_distance_threshold": boundary_distances,
        "distance_smoothing": smoothing,
        "min_size": min_sizes,
    }
Default grid-search parameter for decoder-based instance segmentation.
Arguments:
- center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
- min_size_values: The values for `min_size` used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:
The values for grid search.
def default_grid_search_values_apg(
    min_distance_values: Optional[List[float]] = None,
    threshold_abs_values: Optional[List[float]] = None,
    multimasking_values: Optional[List[bool]] = None,
    prompt_selection_values: Optional[List[Union[str, List[str]]]] = None,
    min_size_values: Optional[List[float]] = None,
    nms_threshold_values: Optional[List[float]] = None,
    intersection_over_min_values: Optional[List[bool]] = None,
    mask_threshold_values: Optional[List[Union[float, str]]] = None,
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
) -> Dict[str, List[Any]]:
    """Default grid-search parameter for APG-based instance segmentation.

    Only the parameters `center_distance_threshold`, `boundary_distance_threshold`, `min_size`,
    `nms_threshold` and `intersection_over_min` are part of the returned grid search values.
    The remaining arguments are accepted, but currently ignored.

    Args:
        min_distance_values: Currently ignored; `min_distance` is not part of the grid search.
        threshold_abs_values: Currently ignored; `threshold_abs` is not part of the grid search.
        multimasking_values: Currently ignored; `multimasking` is not part of the grid search.
        prompt_selection_values: Currently ignored; `prompt_selection` is not part of the grid search.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.
        nms_threshold_values: The values for `nms_threshold` used in the gridsearch.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        intersection_over_min_values: The values for `intersection_over_min` used in the gridsearch.
            By default the values True and False are used.
        mask_threshold_values: Currently ignored; `mask_threshold` is not part of the grid search.
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.

    Returns:
        The values for grid search.
    """
    # NOTE: The two combinations below are for distances. Since we use connected components, we don't run them!
    # if min_distance_values is None:
    #     min_distance_values = _get_range_of_search_values([1, 5], step=1)
    # if threshold_abs_values is None:
    #     threshold_abs_values = _get_range_of_search_values([0.1, 0.5], step=0.1)

    # if multimasking_values is None:
    #     multimasking_values = [True, False]
    # if prompt_selection_values is None:
    #     prompt_selection_values = [
    #         "center_distances",
    #         "boundary_distances",
    #         "connected_components",
    #         ["center_distances", "connected_components"],
    #         ["center_distances", "boundary_distances"],
    #         ["boundary_distances", "connected_components"],
    #         ["center_distances", "boundary_distances", "connected_components"]
    #     ]

    # NOTE: The two parameters below are for connected components.
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)

    if min_size_values is None:
        min_size_values = [50, 100, 200]
    if nms_threshold_values is None:
        nms_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)
    if intersection_over_min_values is None:
        intersection_over_min_values = [True, False]
    # if mask_threshold_values is None:
    #     mask_threshold_values = [None, "auto"]  # 'None' derives the default from the model.

    return {
        # "min_distance": min_distance_values,
        # "threshold_abs": threshold_abs_values,
        # "multimasking": multimasking_values,
        # "prompt_selection": prompt_selection_values,
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "min_size": min_size_values,
        "nms_threshold": nms_threshold_values,
        "intersection_over_min": intersection_over_min_values,
        # "mask_threshold": mask_threshold_values,
    }
Default grid-search parameter for APG-based instance segmentation.
Arguments:
- ...
Returns:
The values for grid search.
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
            Images that already have a result file in this folder are skipped.
        embedding_dir: Folder to cache the image embeddings.
            If None, the segmenter is initialized from the image without precomputed embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to. The entry at position i
            is applied to the i-th image and ground-truth.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            They are passed as keyword arguments to the embedding computation and to `segmenter.initialize`.

    Raises:
        ValueError: If a parameter is given both in `grid_search_values` and in `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
    # The defaults do not depend on the image, so they are resolved once before the loop.
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params

    duplicate_params = [gs_param for gs_param in grid_search_values if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values and map each one back to a valid kwarg input.
    gs_combinations = [
        dict(zip(grid_search_values, vals)) for vals in product(*grid_search_values.values())
    ]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            segmenter.initialize(image, **tiling_window_params)
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
            image_embeddings = util.precompute_image_embeddings(
                predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
            )
            segmenter.initialize(image, image_embeddings, **tiling_window_params)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
Run grid search for automatic mask generation.
The parameters and their respective value ranges for the grid search are specified via the 'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh' and 'stability_score_thresh', you can pass the following:
grid_search_values = {
"pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
"stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
}
All combinations of the parameters will be checked.
You can use the functions default_grid_search_values_instance_segmentation_with_decoder
or default_grid_search_values_amg to get the default grid search parameters for the two
respective instance segmentation methods.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- image_paths: The input images for the grid search.
- gt_paths: The ground-truth segmentation for the grid search.
- result_dir: Folder to cache the evaluation results per image.
- embedding_dir: Folder to cache the image embeddings.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
- image_key: Key for loading the image data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- gt_key: Key for loading the ground-truth data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- rois: Regions of interest to restrict the evaluation to.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Each prediction is written to `prediction_dir` under the file name of its input image.
    Images for which the prediction file already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
            If None, the embeddings are computed without a cache path.
        prediction_dir: Folder to save the predictions. It is created if needed.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation. They are forwarded as keyword arguments to the
            embedding computation and to `segmenter.initialize`.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # Normalize once up front so that the '**' unpacking below never sees None.
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params
    predictor = segmenter._predictor

    # Make sure that the output folder exists, analogous to 'result_dir' in the grid search.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)
        instances = segmenter.generate(**generate_kwargs)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
Run inference for automatic mask generation.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- image_paths: The input images.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
- tiling_window_params: The parameters to decide whether to use tiling window operation for automatic segmentation.
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    All csv files in `result_dir` are loaded and concatenated. The `criterion` is then
    averaged over all rows that share the same grid search parameter values,
    and the parameter combination with the highest average is selected.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv file.
    """
    # Load all the grid search results.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"Could not find any grid search results (*.csv) in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Retrieve only the relevant columns.
    parameter_names = list(grid_search_parameters)
    gs_result = gs_result[parameter_names + [criterion]]

    # Compute the mean of the criterion for each parameter combination.
    # 'reset_index' turns the group keys back into columns and yields a 0..n-1 index.
    grouped_result = gs_result.groupby(parameter_names).mean().reset_index()

    # Find the best score and corresponding parameters.
    best_score = grouped_result[criterion].max()
    best_params = grouped_result.loc[grouped_result[criterion].idxmax()]
    assert np.isclose(best_params[criterion], best_score)

    # Look the parameters up by name rather than relying on the column order.
    best_kwargs = {name: best_params[name] for name in parameter_names}

    return best_kwargs, best_score
Evaluate gridsearch results.
Arguments:
- result_dir: The folder with the gridsearch results.
- grid_search_parameters: The names for the gridsearch parameters.
- criterion: The metric to use for determining the best parameters.
Returns:
The best parameter setting. The evaluation score for the best setting.
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best parameters found by the grid search as a csv file.

    The table has a single row that contains the score in the column 'best_msa',
    followed by one column per parameter.

    Args:
        best_kwargs: Dictionary that maps the parameter names to their best values.
        best_msa: The score that was achieved with the best parameters.
        grid_search_result_dir: The folder for saving the result. If given, the file is written
            to the sub-folder 'results', which is created if needed.
            If None, the file is written to the current working directory.
    """
    # saving the best parameters estimated from grid-search in the `results` folder
    param_df = pd.DataFrame.from_dict([best_kwargs])
    res_df = pd.DataFrame.from_dict([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # Both parameters must be present to identify an AMG grid search. Note that
    # '"a" and "b" in d' would only check for "b", since a non-empty string is always truthy.
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    if is_amg:
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    The grid search is run on the validation images, the best parameters are printed
    and saved, and then inference is run on the test images with the fixed
    parameters combined with the best grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            The passed dictionary is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Build a new dictionary, so that the best parameters do not leak into the caller's
    # 'fixed_generate_kwargs'. The best parameters take precedence over the fixed ones.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    # NOTE: Make sure the 'prompt_selection' values for APG are as expected
    if "prompt_selection" in generate_kwargs:
        generate_kwargs["prompt_selection"] = _maybe_list_value(generate_kwargs["prompt_selection"])

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )
Run grid search and inference for automatic mask generation.
Please refer to the documentation of run_instance_segmentation_grid_search
for details on how to specify the grid search parameters.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- val_image_paths: The input images for the grid search.
- val_gt_paths: The ground-truth segmentation for the grid search.
- test_image_paths: The input images for inference.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- experiment_folder: Folder for caching best grid search parameters in 'results'.
- result_dir: Folder to cache the evaluation results per image.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
- tiling_window_params: The parameters to decide whether to use tiling window operation for automatic segmentation.