micro_sam.evaluation.instance_segmentation

Inference and evaluation for the automatic instance segmentation functionality.

  1"""Inference and evaluation for the automatic instance segmentation functionality.
  2"""
  3
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from itertools import product
  9from typing import Any, Dict, List, Optional, Tuple, Union
 10
 11import numpy as np
 12import pandas as pd
 13import imageio.v3 as imageio
 14
 15from elf.io import open_file
 16from elf.evaluation import mean_segmentation_accuracy, matching
 17
 18from .. import util
 19from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder
 20
 21
 22def _get_range_of_search_values(input_vals, step):
 23    if isinstance(input_vals, list):
 24        search_range = np.arange(input_vals[0], input_vals[1] + step, step)
 25        search_range = [round(e, 3) for e in search_range]
 26    else:
 27        search_range = [input_vals]
 28    return search_range
 29
 30
 31def default_grid_search_values_amg(
 32    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
 33) -> Dict[str, List[float]]:
 34    """Default grid-search parameter for AMG-based instance segmentation.
 35
 36    Return grid search values for the two most important parameters:
 37    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
 38    - `stability_score_thresh`, the theshold for keepong objects according to their stability.
 39
 40    Args:
 41        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
 42            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 43        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
 44            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 45
 46    Returns:
 47        The values for grid search.
 48    """
 49    if iou_thresh_values is None:
 50        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
 51    if stability_score_values is None:
 52        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
 53    return {
 54        "pred_iou_thresh": iou_thresh_values,
 55        "stability_score_thresh": stability_score_values,
 56    }
 57
 58
 59def default_grid_search_values_instance_segmentation_with_decoder(
 60    center_distance_threshold_values: Optional[List[float]] = None,
 61    boundary_distance_threshold_values: Optional[List[float]] = None,
 62    distance_smoothing_values: Optional[List[float]] = None,
 63    min_size_values: Optional[List[float]] = None,
 64) -> Dict[str, List[float]]:
 65    """Default grid-search parameter for decoder-based instance segmentation.
 66
 67    Args:
 68        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 69            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 70        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 71            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 72        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 73            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
 74        min_size_values: The values for `min_size` used in the gridsearch.
 75            By default the values 50, 100 and 200  are used.
 76
 77    Returns:
 78        The values for grid search.
 79    """
 80    if center_distance_threshold_values is None:
 81        center_distance_threshold_values = _get_range_of_search_values(
 82            [0.3, 0.7], step=0.1
 83        )
 84    if boundary_distance_threshold_values is None:
 85        boundary_distance_threshold_values = _get_range_of_search_values(
 86            [0.3, 0.7], step=0.1
 87        )
 88    if distance_smoothing_values is None:
 89        distance_smoothing_values = _get_range_of_search_values(
 90            [1.0, 2.0], step=0.2
 91        )
 92    if min_size_values is None:
 93        min_size_values = [50, 100, 200]
 94
 95    return {
 96        "center_distance_threshold": center_distance_threshold_values,
 97        "boundary_distance_threshold": boundary_distance_threshold_values,
 98        "distance_smoothing": distance_smoothing_values,
 99        "min_size": min_size_values,
100    }
101
102
def default_grid_search_values_apg(
    min_distance_values: Optional[List[float]] = None,
    threshold_abs_values: Optional[List[float]] = None,
    multimasking_values: Optional[List[float]] = None,
    prompt_selection_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
    nms_threshold_values: Optional[List[float]] = None,
    intersection_over_min_values: Optional[List[bool]] = None,
    mask_threshold_values: Optional[List[Union[float, str]]] = None,
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
) -> Dict[str, List[Any]]:
    """Default grid-search parameter for APG-based instance segmentation.

    By default the grid search covers `center_distance_threshold`, `boundary_distance_threshold`,
    `min_size`, `nms_threshold` and `intersection_over_min`. The parameters `min_distance`, `threshold_abs`,
    `multimasking`, `prompt_selection` and `mask_threshold` have no default search range; they are only
    added to the grid search when values for them are passed explicitly.

    Args:
        min_distance_values: The values for `min_distance`. Not searched unless given.
        threshold_abs_values: The values for `threshold_abs`. Not searched unless given.
        multimasking_values: The values for `multimasking`. Not searched unless given.
        prompt_selection_values: The values for `prompt_selection`. Not searched unless given.
        min_size_values: The values for `min_size`. By default the values 50, 100 and 200 are used.
        nms_threshold_values: The values for `nms_threshold`.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        intersection_over_min_values: The values for `intersection_over_min`. By default True and False are used.
        mask_threshold_values: The values for `mask_threshold`. Not searched unless given.
        center_distance_threshold_values: The values for `center_distance_threshold`.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold`.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.

    Returns:
        The values for grid search, keyed by the name of the corresponding `generate` parameter.
    """
    # NOTE: The two thresholds below are used for deriving prompts via connected components.
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)

    if min_size_values is None:
        min_size_values = [50, 100, 200]
    if nms_threshold_values is None:
        nms_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)
    if intersection_over_min_values is None:
        intersection_over_min_values = [True, False]

    grid_search_values = {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "min_size": min_size_values,
        "nms_threshold": nms_threshold_values,
        "intersection_over_min": intersection_over_min_values,
    }

    # These parameters are not part of the default search (e.g. the distance-based ones are not needed when
    # prompts come from connected components). Previously, values passed for them were silently dropped;
    # now they are searched whenever the caller explicitly provides them.
    optional_values = {
        "min_distance": min_distance_values,
        "threshold_abs": threshold_abs_values,
        "multimasking": multimasking_values,
        "prompt_selection": prompt_selection_values,
        "mask_threshold": mask_threshold_values,
    }
    grid_search_values.update({name: values for name, values in optional_values.items() if values is not None})
    return grid_search_values
169
170
def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Score every grid-search combination on one (already initialized) image.

    For each combination, `segmenter.generate` is called with the combination merged with the fixed
    keyword arguments, and the result is compared against the ground-truth.

    Args:
        segmenter: The segmenter, already initialized for the current image.
        gs_combinations: One keyword-argument dict per parameter combination.
        gt: The ground-truth segmentation.
        image_name: Name of the image, stored in the 'image_name' column.
        fixed_generate_kwargs: Keyword arguments passed to `generate` for every combination.
        result_path: Where the per-image scores are written as csv.
        verbose: Whether to show a progress bar over the combinations.

    Returns:
        One row per combination with the scores and the parameter values.
    """
    per_combination_frames = []
    for params in tqdm(gs_combinations, disable=not verbose):
        prediction = segmenter.generate(**(params | fixed_generate_kwargs))

        msa, accuracies = mean_segmentation_accuracy(prediction, gt, return_accuracies=True)
        match_stats = matching(prediction, gt)

        # NOTE(review): positions 0 and 5 are taken as SA at IoU 0.5 and 0.75, which assumes the
        # accuracies from `mean_segmentation_accuracy` start at 0.5 in steps of 0.05 - confirm in elf.
        row = {
            "image_name": image_name,
            "mSA": msa,
            "SA50": accuracies[0],
            "SA75": accuracies[5],
            "Precision": match_stats["precision"],
            "Recall": match_stats["recall"],
            "F1": match_stats["f1"],
        }
        row.update(params)
        per_combination_frames.append(pd.DataFrame([row]))

    image_results = pd.concat(per_combination_frames)
    image_results.to_csv(result_path, index=False)
    return image_results
204
205
def _load_image(path, key, roi):
    """Load image data, optionally restricted to a region of interest.

    Args:
        path: The file to load from.
        key: The dataset key inside a container file (e.g. HDF5). If None, `path` is read as a plain image.
        roi: Optional index expression to crop the data with.

    Returns:
        The loaded (and possibly cropped) data.
    """
    if key is not None:
        with open_file(path, "r") as f:
            dataset = f[key]
            return dataset[:] if roi is None else dataset[roi]

    image = imageio.imread(path)
    return image if roi is None else image[roi]
216
217
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[Tuple[slice, ...]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`,
    `default_grid_search_values_amg` or `default_grid_search_values_apg` to get the default grid search
    parameters for the respective instance segmentation methods.

    The scores for each image are written to '<result_dir>/<image file stem>.csv'. Images for which this
    file already exists are skipped, so an interrupted grid search can be resumed.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to, one per image:
            `rois[i]` is applied to both the i-th image and the i-th ground-truth.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

    Raises:
        ValueError: If a parameter is given both in `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params

    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values and map each one back to a valid kwarg input.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            segmenter.initialize(image, **tiling_window_params)
        else:
            assert predictor is not None
            # `image_name` is already the file stem; stripping another suffix would make names such as
            # 'a.1.tif' and 'a.2.tif' share one cached embedding.
            embedding_path = os.path.join(embedding_dir, f"{image_name}.zarr")
            image_embeddings = util.precompute_image_embeddings(
                predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
            )
            segmenter.initialize(image, image_embeddings, **tiling_window_params)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
322
323
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Each prediction is written to `prediction_dir` under the same file name as its input image.
    Images whose prediction file already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
            If None, `None` is passed as the embedding path to `util.precompute_image_embeddings`.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params
    predictor = segmenter._predictor

    # Make sure the output folder exists, otherwise writing the first prediction fails.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)
        instances = segmenter.generate(**generate_kwargs)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
378
379
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    The per-image csv files in `result_dir` are combined. The `criterion` is averaged over all images
    for each parameter combination, and the combination with the highest average is selected.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files.
    """
    # Load all the grid search results.
    gs_files = glob(os.path.join(result_dir, "*.csv"))
    if not gs_files:
        raise ValueError(f"No grid-search results (*.csv) were found in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])

    # Retrieve only the relevant columns and average the criterion per parameter combination.
    grid_search_parameters = list(grid_search_parameters)
    gs_result = gs_result[grid_search_parameters + [criterion]]
    grouped_result = gs_result.groupby(grid_search_parameters).mean().reset_index()

    # Find the best score and corresponding parameters.
    best_idx = grouped_result[criterion].idxmax()
    best_score = grouped_result.at[best_idx, criterion]
    # Read each parameter from its own column, so that its dtype is preserved (e.g. `min_size` stays an
    # integer instead of being upcast to float together with the score column).
    best_kwargs = {name: grouped_result.at[best_idx, name] for name in grid_search_parameters}

    return best_kwargs, best_score
411
412
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best grid-search parameters and their score to a csv file.

    The file is called 'grid_search_params_amg.csv' if the parameters contain both `pred_iou_thresh`
    and `stability_score_thresh`, and 'grid_search_params_instance_segmentation_with_decoder.csv' otherwise.

    Args:
        best_kwargs: The best parameter setting, mapping parameter names to values.
        best_msa: The evaluation score for the best setting.
        grid_search_result_dir: Folder in whose 'results' subfolder the file is saved.
            If not given, the file is saved in the current working directory.
    """
    # A single row: the score first, followed by the parameters.
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # Both keys must be checked explicitly; `"a" and "b" in d` would only test for "b".
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    path_name = "grid_search_params_amg.csv" if is_amg \
        else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
429
430
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    The grid search runs on the validation images. The best parameters are stored in
    '<experiment_folder>/results' and combined with `fixed_generate_kwargs` for inference on the test images.
    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            This dictionary is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Copy the fixed kwargs: updating them in place would leak the best parameters into the caller's dict.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    # NOTE: Make sure the 'prompt_selection' values for APG are as expected
    # (they may come back from the csv results as the string representation of a list).
    if "prompt_selection" in generate_kwargs:
        generate_kwargs["prompt_selection"] = _maybe_list_value(generate_kwargs["prompt_selection"])

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )
499
500
def _maybe_list_value(val):
    """Turn a string holding a Python list literal back into a list.

    Parameter values read back from the grid-search csv files may hold the string representation
    of a list (e.g. for the 'prompt_selection' of APG).

    Args:
        val: The value to convert.

    Returns:
        The parsed list if `val` is a string containing a list literal, otherwise `val` unchanged.
    """
    # Non-string values are returned as they are.
    if not isinstance(val, str):
        return val

    s = val.strip()
    # Only try to parse values that obviously look like a list.
    if not (s.startswith("[") and s.endswith("]")):
        return val

    import ast

    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError):
        # Bracketed strings that are no valid literal are kept as they are instead of crashing.
        return val
    return parsed if isinstance(parsed, list) else val
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by the name of the corresponding `generate` parameter.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }

Default grid-search parameter for AMG-based instance segmentation.

Return grid search values for the two most important parameters:

  • pred_iou_thresh, the threshold for keeping objects according to the IoU predicted by the model.
  • stability_score_thresh, the threshold for keeping objects according to their stability.
Arguments:
  • iou_thresh_values: The values for pred_iou_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
  • stability_score_values: The values for stability_score_thresh used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:

The values for grid search.

def default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for decoder-based instance segmentation.

    Args:
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.

    Returns:
        The values for grid search, keyed by the name of the corresponding `generate` parameter.
    """
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if distance_smoothing_values is None:
        distance_smoothing_values = _get_range_of_search_values([1.0, 2.0], step=0.2)
    if min_size_values is None:
        min_size_values = [50, 100, 200]

    return {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "distance_smoothing": distance_smoothing_values,
        "min_size": min_size_values,
    }

Default grid-search parameter for decoder-based instance segmentation.

Arguments:
  • center_distance_threshold_values: The values for center_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • boundary_distance_threshold_values: The values for boundary_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • distance_smoothing_values: The values for distance_smoothing used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
  • min_size_values: The values for min_size used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:

The values for grid search.

def default_grid_search_values_apg(
    min_distance_values: Optional[List[float]] = None,
    threshold_abs_values: Optional[List[float]] = None,
    multimasking_values: Optional[List[float]] = None,
    prompt_selection_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
    nms_threshold_values: Optional[List[float]] = None,
    intersection_over_min_values: Optional[List[bool]] = None,
    mask_threshold_values: Optional[List[Union[float, str]]] = None,
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
) -> Dict[str, List[Any]]:
    """Default grid-search parameter for APG-based instance segmentation.

    By default the grid search covers `center_distance_threshold`, `boundary_distance_threshold`,
    `min_size`, `nms_threshold` and `intersection_over_min`. The parameters `min_distance`, `threshold_abs`,
    `multimasking`, `prompt_selection` and `mask_threshold` have no default search range; they are only
    added to the grid search when values for them are passed explicitly.

    Args:
        min_distance_values: The values for `min_distance`. Not searched unless given.
        threshold_abs_values: The values for `threshold_abs`. Not searched unless given.
        multimasking_values: The values for `multimasking`. Not searched unless given.
        prompt_selection_values: The values for `prompt_selection`. Not searched unless given.
        min_size_values: The values for `min_size`. By default the values 50, 100 and 200 are used.
        nms_threshold_values: The values for `nms_threshold`.
            By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
        intersection_over_min_values: The values for `intersection_over_min`. By default True and False are used.
        mask_threshold_values: The values for `mask_threshold`. Not searched unless given.
        center_distance_threshold_values: The values for `center_distance_threshold`.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold`.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.

    Returns:
        The values for grid search, keyed by the name of the corresponding `generate` parameter.
    """
    # NOTE: The two thresholds below are used for deriving prompts via connected components.
    if center_distance_threshold_values is None:
        center_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)
    if boundary_distance_threshold_values is None:
        boundary_distance_threshold_values = _get_range_of_search_values([0.3, 0.7], step=0.1)

    if min_size_values is None:
        min_size_values = [50, 100, 200]
    if nms_threshold_values is None:
        nms_threshold_values = _get_range_of_search_values([0.5, 0.9], step=0.1)
    if intersection_over_min_values is None:
        intersection_over_min_values = [True, False]

    grid_search_values = {
        "center_distance_threshold": center_distance_threshold_values,
        "boundary_distance_threshold": boundary_distance_threshold_values,
        "min_size": min_size_values,
        "nms_threshold": nms_threshold_values,
        "intersection_over_min": intersection_over_min_values,
    }

    # These parameters are not part of the default search (e.g. the distance-based ones are not needed when
    # prompts come from connected components). Previously, values passed for them were silently dropped;
    # now they are searched whenever the caller explicitly provides them.
    optional_values = {
        "min_distance": min_distance_values,
        "threshold_abs": threshold_abs_values,
        "multimasking": multimasking_values,
        "prompt_selection": prompt_selection_values,
        "mask_threshold": mask_threshold_values,
    }
    grid_search_values.update({name: values for name, values in optional_values.items() if values is not None})
    return grid_search_values

Default grid-search parameter for APG-based instance segmentation.

Arguments:
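  • center_distance_threshold_values: The values for center_distance_threshold. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • boundary_distance_threshold_values: The values for boundary_distance_threshold. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • min_size_values: The values for min_size. By default the values 50, 100 and 200 are used.
  • nms_threshold_values: The values for nms_threshold. By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
  • intersection_over_min_values: The values for intersection_over_min. By default True and False are used.
  • min_distance_values, threshold_abs_values, multimasking_values, prompt_selection_values, mask_threshold_values: These parameters are not part of the default grid search.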
Returns:

The values for grid search.

def run_instance_segmentation_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike, NoneType], prediction_dir: Union[str, os.PathLike], generate_kwargs: Optional[Dict[str, Any]] = None, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> None:
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Each prediction is written to `prediction_dir` under the same file name as its input image.
    Images whose prediction file already exists are skipped, so an interrupted run can be resumed.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings. If None, no embedding path is
            passed to the embedding computation.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation. They are forwarded both to the embedding computation
            and to `segmenter.initialize`.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # Loop-invariant: resolve the default once instead of re-checking it for every image.
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params
    predictor = segmenter._predictor

    # Make sure the output folder exists before the first prediction is written.
    if prediction_dir:
        os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)
        instances = segmenter.generate(**generate_kwargs)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

Run inference for automatic mask generation.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • image_paths: The input images.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • generate_kwargs: The keyword arguments for the generate method of the segmenter.
  • tiling_window_params: The parameters that decide whether tiled (windowed) processing is used for automatic segmentation.
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best grid-search parameters together with their score to a csv file.

    If both `pred_iou_thresh` and `stability_score_thresh` are among the parameters, they are
    treated as AMG parameters and saved to 'grid_search_params_amg.csv'. Otherwise they are
    saved to 'grid_search_params_instance_segmentation_with_decoder.csv'.

    Args:
        best_kwargs: The best parameter values found by the grid search.
        best_msa: The score achieved with `best_kwargs` (the `best_msa` value of the grid search).
        grid_search_result_dir: Folder in whose 'results' subfolder the csv file is saved.
            If None, the csv file is written to the current working directory.
    """
    # A single row: the score first, followed by one column per parameter.
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # Both keys must be checked explicitly: `"a" and "b" in d` only tests `"b" in d`,
    # because the non-empty string "a" is always truthy.
    is_amg = "pred_iou_thresh" in best_kwargs and "stability_score_thresh" in best_kwargs
    path_name = "grid_search_params_amg.csv" if is_amg \
        else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        # Saving the best parameters estimated from grid-search in the `results` folder.
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], grid_search_values: Dict[str, List], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike, NoneType], prediction_dir: Union[str, os.PathLike], experiment_folder: Union[str, os.PathLike], result_dir: Union[str, os.PathLike], fixed_generate_kwargs: Optional[Dict[str, Any]] = None, verbose_gs: bool = True, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> None:
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    The grid search is run on the validation images. The best parameters are printed and saved
    to `experiment_folder`, then combined with `fixed_generate_kwargs` for inference on the
    test images. `fixed_generate_kwargs` itself is not modified.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid search for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Work on a copy: updating the caller's `fixed_generate_kwargs` in place would leak the
    # grid-search results into a dict the caller may reuse (e.g. for another dataset).
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    # NOTE: Make sure the 'prompt_selection' values for APG are as expected
    if "prompt_selection" in generate_kwargs:
        generate_kwargs["prompt_selection"] = _maybe_list_value(generate_kwargs["prompt_selection"])

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )

Run grid search and inference for automatic mask generation.

Please refer to the documentation of run_instance_segmentation_grid_search for details on how to specify the grid search parameters.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • grid_search_values: The grid search values for parameters of the generate function.
  • val_image_paths: The input images for the grid search.
  • val_gt_paths: The ground-truth segmentation for the grid search.
  • test_image_paths: The input images for inference.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • experiment_folder: Folder for caching best grid search parameters in 'results'.
  • result_dir: Folder to cache the evaluation results per image.
  • fixed_generate_kwargs: Fixed keyword arguments for the generate method of the segmenter.
  • verbose_gs: Whether to run the grid search for individual images in verbose mode.
  • tiling_window_params: The parameters that decide whether tiled (windowed) processing is used for automatic segmentation.