micro_sam.evaluation.instance_segmentation
Inference and evaluation for the automatic instance segmentation functionality.
"""Inference and evaluation for the automatic instance segmentation functionality.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import imageio.v3 as imageio

from elf.io import open_file
from elf.evaluation import mean_segmentation_accuracy

from .. import util
from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation


def _get_range_of_search_values(input_vals, step):
    """Expand a `[start, stop]` pair into a list of search values.

    Args:
        input_vals: Either a list whose first two entries are the lower and upper bound,
            or a single value.
        step: The step size passed to `np.arange`.

    Returns:
        The list of values rounded to three decimals, or `[input_vals]` if a single value was given.
    """
    if isinstance(input_vals, list):
        # `step` is added to the upper bound so that the bound itself can be part of the range.
        search_range = np.arange(input_vals[0], input_vals[1] + step, step)
        search_range = [round(e, 3) for e in search_range]
    else:
        search_range = [input_vals]
    return search_range


def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }


def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search.
    """
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values(
            [0.3, 0.7], step=0.1
        )
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values(
            [0.3, 0.7], step=0.1
        )
    if distance_smoothing_values is None:
        distance_smoothing_values = _get_range_of_search_values(
            [1.0, 2.0], step=0.2
        )
    if min_size_values is None:
        min_size_values = [50, 100, 200]
    return {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }


def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Evaluate all grid-search parameter combinations for one image.

    The segmenter must already be initialized for the image that belongs to `gt`.

    Args:
        segmenter: The initialized segmenter whose `generate` method is called.
        gs_combinations: One keyword-argument dictionary per parameter combination.
        gt: The ground-truth segmentation the predictions are compared to.
        image_name: The name stored in the `image_name` column of the result.
        fixed_generate_kwargs: Keyword arguments passed to every `generate` call.
        result_path: Where to save the result table as csv. Nothing is written if it is None.
        verbose: Whether to show a progress bar over the parameter combinations.

    Returns:
        A table with one row per parameter combination.
    """
    rows = []
    for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
        generate_kwargs = gs_kwargs | fixed_generate_kwargs
        masks = segmenter.generate(**generate_kwargs)

        min_object_size = generate_kwargs.get("min_mask_region_area", 0)
        if len(masks) == 0:
            # No objects were found, so we evaluate against an empty segmentation.
            instance_labels = np.zeros(gt.shape, dtype="uint32")
        else:
            instance_labels = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
        m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True)  # type: ignore

        # NOTE(review): entries 0 and 5 presumably belong to the IoU thresholds 0.5 and 0.75; confirm against elf.
        rows.append({"image_name": image_name, "mSA": m_sas, "SA50": sas[0], "SA75": sas[5], **gs_kwargs})

    # Build the table once instead of concatenating one single-row table per combination.
    img_gs_df = pd.DataFrame(rows)
    if result_path is not None:
        img_gs_df.to_csv(result_path, index=False)

    return img_gs_df


def _load_image(path, key, roi):
    """Load image data from `path`, optionally restricted to `roi`.

    Args:
        path: The file to load.
        key: The internal dataset key. If None the file is read with imageio.
        roi: The index expression applied to the data, or None to load everything.

    Returns:
        The loaded data.
    """
    if key is None:
        im = imageio.imread(path)
        if roi is not None:
            im = im[roi]
        return im
    with open_file(path, "r") as f:
        im = f[key][:] if roi is None else f[key][roi]
    return im


def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
            If None the embeddings are computed by the segmenter and not cached.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to.
            The entry at position i is applied to the i-th image and ground-truth.

    Raises:
        ValueError: If a parameter is given in both `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs

    duplicate_params = [gs_param for gs_param in grid_search_values if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values and map each one back to a valid kwarg input.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            segmenter.initialize(image)
        else:
            assert predictor is not None
            # `image_name` is already the stem. Stripping the extension a second time would map
            # names like 'a.b.tif' and 'a.c.tif' to the same embedding file.
            embedding_path = os.path.join(embedding_dir, f"{image_name}.zarr")
            image_embeddings = util.precompute_image_embeddings(
                predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
            )
            segmenter.initialize(image, image_embeddings)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )


def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
    """

    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    predictor = segmenter._predictor
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Writing the predictions fails if the output folder is missing.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)
        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)


def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike],
    grid_search_parameters: List[str],
    criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files.
    """

    # Load all the grid search results.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"Could not find any grid search results (*.csv) in {result_dir}.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Retrieve only the relevant columns; the per-file row index carries no information.
    gs_result = gs_result[grid_search_parameters + [criterion]]

    # Compute the mean of the criterion for each parameter combination.
    grouped_result = gs_result.groupby(grid_search_parameters)[criterion].mean().reset_index()

    # Find the best score and corresponding parameters.
    best_idx = grouped_result[criterion].idxmax()
    best_score = float(grouped_result.loc[best_idx, criterion])

    # Read the parameters column by column: extracting the whole row as a Series
    # would upcast integer parameters (e.g. `min_size`) to float.
    best_kwargs = {}
    for name in grid_search_parameters:
        value = grouped_result.loc[best_idx, name]
        best_kwargs[name] = value.item() if isinstance(value, np.generic) else value

    return best_kwargs, best_score


def save_grid_search_best_params(
    best_kwargs: Dict[str, Any],
    best_msa: float,
    grid_search_result_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Save the best grid-search parameters and their score to a csv file.

    The file is called 'grid_search_params_amg.csv' if `best_kwargs` contains both
    'pred_iou_thresh' and 'stability_score_thresh', and
    'grid_search_params_instance_segmentation_with_decoder.csv' otherwise.

    Args:
        best_kwargs: The best parameter setting found by the grid search.
        best_msa: The score of the best parameter setting.
        grid_search_result_dir: The parent folder for the results. The file is written to its
            'results' sub-folder, which is created if needed. If None, the file is written
            to the current working directory.
    """
    # saving the best parameters estimated from grid-search in the `results` folder
    param_df = pd.DataFrame.from_dict([best_kwargs])
    res_df = pd.DataFrame.from_dict([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # Both keys must be tested explicitly: `"a" and "b" in d` only checks for "b".
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    path_name = "grid_search_params_amg.csv" if is_amg \
        else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)


def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    The best parameters are also saved to the 'results' folder next to `embedding_dir`.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
    """
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, val_image_paths, val_gt_paths,
        result_dir=result_dir, embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent)

    # Copy the fixed arguments so that the dictionary of the caller is not modified.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
    )
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Return the grid-search values for AMG-based instance segmentation.

    The grid search covers two parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            If None, a range built from the bounds 0.6 and 0.9 with a stepsize of 0.025 is used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            If None, a range built from the bounds 0.6 and 0.95 with a stepsize of 0.025 is used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    search_values = {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }
    # Bounds of the default range for every parameter that was not given explicitly.
    default_bounds = {
        "pred_iou_thresh": [0.6, 0.9],
        "stability_score_thresh": [0.6, 0.95],
    }
    for param_name, bounds in default_bounds.items():
        if search_values[param_name] is None:
            search_values[param_name] = _get_range_of_search_values(bounds, step=0.025)
    return search_values
Default grid-search parameter for AMG-based instance segmentation.
Return grid search values for the two most important parameters:
pred_iou_thresh
, the threshold for keeping objects according to the IoU predicted by the model.stability_score_thresh
, the threshold for keeping objects according to their stability.
Arguments:
- iou_thresh_values: The values for
pred_iou_thresh
used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. - stability_score_values: The values for
stability_score_thresh
used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:
The values for grid search.
def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Return the grid-search values for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            If None, a range built from the bounds 0.3 and 0.7 with a stepsize of 0.1 is used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            If None, a range built from the bounds 0.3 and 0.7 with a stepsize of 0.1 is used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            If None, a range built from the bounds 1.0 and 2.0 with a stepsize of 0.2 is used.
        min_size_values: The values for `min_size` used in the gridsearch.
            If None, the values 50, 100 and 200 are used.

    Returns:
        The values for grid search, keyed by parameter name.
    """
    search_values = {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }
    # (bounds, step) of the default range for every range-based parameter.
    default_ranges = {
        "center_distance_threshold": ([0.3, 0.7], 0.1),
        "boundary_distance_threshold": ([0.3, 0.7], 0.1),
        "distance_smoothing": ([1.0, 2.0], 0.2),
    }
    for param_name, (bounds, step) in default_ranges.items():
        if search_values[param_name] is None:
            search_values[param_name] = _get_range_of_search_values(bounds, step=step)
    if search_values["min_size"] is None:
        search_values["min_size"] = [50, 100, 200]
    return search_values
Default grid-search parameter for decoder-based instance segmentation.
Arguments:
- center_distance_threshold_values: The values for
center_distance_threshold
used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used. - boundary_distance_threshold_values: The values for
boundary_distance_threshold
used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used. - distance_smoothing_values: The values for
distance_smoothing
used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used. - min_size_values: The values for
min_size
used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:
The values for grid search.
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
            If None the embeddings are computed by the segmenter and not cached.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to.
            The entry at position i is applied to the i-th image and ground-truth.

    Raises:
        ValueError: If a parameter is given in both `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs

    duplicate_params = [gs_param for gs_param in grid_search_values if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        # The trailing space keeps the two sentences of the message separated.
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values and map each one back to a valid kwarg input.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            segmenter.initialize(image)
        else:
            assert predictor is not None
            # `image_name` is already the stem. Stripping the extension a second time would map
            # names like 'a.b.tif' and 'a.c.tif' to the same embedding file.
            embedding_path = os.path.join(embedding_dir, f"{image_name}.zarr")
            image_embeddings = util.precompute_image_embeddings(
                predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
            )
            segmenter.initialize(image, image_embeddings)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
Run grid search for automatic mask generation.
The parameters and their respective value ranges for the grid search are specified via the 'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh' and 'stability_score_thresh', you can pass the following:
grid_search_values = {
"pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
"stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
}
All combinations of the parameters will be checked.
You can use the functions default_grid_search_values_instance_segmentation_with_decoder
or default_grid_search_values_amg
to get the default grid search parameters for the two
respective instance segmentation methods.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the
generate
function. - image_paths: The input images for the grid search.
- gt_paths: The ground-truth segmentation for the grid search.
- result_dir: Folder to cache the evaluation results per image.
- embedding_dir: Folder to cache the image embeddings.
- fixed_generate_kwargs: Fixed keyword arguments for the
generate
method of the segmenter. - verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
- image_key: Key for loading the image data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- gt_key: Key for loading the ground-truth data from a more complex file format like HDF5. If not given a simple image format like tif is assumed.
- rois: Regions of interest to restrict the evaluation to.
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Images that already have a prediction in `prediction_dir` are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
    """

    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    predictor = segmenter._predictor
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Writing the predictions fails if the output folder is missing.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)
        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            # The shape of the empty label image is taken from the segmenter state if available.
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
Run inference for automatic mask generation.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- image_paths: The input images.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- generate_kwargs: The keyword arguments for the
generate
method of the segmenter.
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike],
    grid_search_parameters: List[str],
    criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    All csv files in `result_dir` are loaded, the criterion is averaged per parameter
    combination and the combination with the highest average is returned.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files.
    """

    # Load all the grid search results.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"Could not find any grid search results (*.csv) in {result_dir}.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Retrieve only the relevant columns; the per-file row index carries no information.
    gs_result = gs_result[grid_search_parameters + [criterion]]

    # Compute the mean of the criterion for each parameter combination.
    grouped_result = gs_result.groupby(grid_search_parameters)[criterion].mean().reset_index()

    # Find the best score and corresponding parameters.
    # `idxmax` returns an index label, so the row is selected with `loc`.
    best_idx = grouped_result[criterion].idxmax()
    best_score = float(grouped_result.loc[best_idx, criterion])

    # Read the parameters column by column: extracting the whole row as a Series
    # would upcast integer parameters (e.g. `min_size`) to float.
    best_kwargs = {}
    for name in grid_search_parameters:
        value = grouped_result.loc[best_idx, name]
        best_kwargs[name] = value.item() if isinstance(value, np.generic) else value

    return best_kwargs, best_score
Evaluate gridsearch results.
Arguments:
- result_dir: The folder with the gridsearch results.
- grid_search_parameters: The names for the gridsearch parameters.
- criterion: The metric to use for determining the best parameters.
Returns:
The best parameter setting. The evaluation score for the best setting.
def save_grid_search_best_params(
    best_kwargs: Dict[str, Any],
    best_msa: float,
    grid_search_result_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Save the best parameters found by the grid search to a csv file.

    The csv file contains a single row with the column `best_msa`, followed by one column per parameter.

    Args:
        best_kwargs: The best parameter setting, mapping the parameter names to their values.
        best_msa: The evaluation score for the best setting.
        grid_search_result_dir: The folder for saving the result. The csv file is written to the
            subfolder `results` of this folder. By default it is written to the current working directory.
    """
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # The file name depends on the segmentation method, which is identified via the AMG-specific parameters.
    # Each membership test has to be spelled out: `"a" and "b" in d` only checks for "b".
    # A single AMG parameter is sufficient here, because the grid search may cover only one of them.
    amg_parameters = ("pred_iou_thresh", "stability_score_thresh")
    if any(name in best_kwargs for name in amg_parameters):
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is None:
        res_path = path_name
    else:
        result_folder = os.path.join(grid_search_result_dir, "results")
        os.makedirs(result_folder, exist_ok=True)
        res_path = os.path.join(result_folder, path_name)

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
) -> None:
    """Run grid search and inference for automatic mask generation.

    The grid search is run on the validation images, the best parameter setting is determined
    and saved, and inference is then run on the test images with the best setting.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            This dictionary is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
    """
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, val_image_paths, val_gt_paths,
        result_dir=result_dir, embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    # The best parameters are saved relative to the parent folder of the embedding folder.
    save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent)

    # Combine the fixed arguments and the best parameters in a new dictionary,
    # so that the dictionary passed by the caller is not changed. The best parameters take precedence.
    generate_kwargs = {**(fixed_generate_kwargs or {}), **best_kwargs}

    run_instance_segmentation_inference(
        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
    )
Run grid search and inference for automatic mask generation.
Please refer to the documentation of run_instance_segmentation_grid_search
for details on how to specify the grid search parameters.
Arguments:
- segmenter: The class implementing the instance segmentation functionality.
- grid_search_values: The grid search values for parameters of the `generate` function.
- val_image_paths: The input images for the grid search.
- val_gt_paths: The ground-truth segmentation for the grid search.
- test_image_paths: The input images for inference.
- embedding_dir: Folder to cache the image embeddings.
- prediction_dir: Folder to save the predictions.
- result_dir: Folder to cache the evaluation results per image.
- fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
- verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.