micro_sam.evaluation.instance_segmentation

Inference and evaluation for the automatic instance segmentation functionality.

  1"""Inference and evaluation for the automatic instance segmentation functionality.
  2"""
  3
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from itertools import product
  9from typing import Any, Dict, List, Optional, Tuple, Union
 10
 11import numpy as np
 12import pandas as pd
 13import imageio.v3 as imageio
 14
 15from elf.io import open_file
 16from elf.evaluation import mean_segmentation_accuracy
 17
 18from .. import util
 19from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
 20
 21
 22def _get_range_of_search_values(input_vals, step):
 23    if isinstance(input_vals, list):
 24        search_range = np.arange(input_vals[0], input_vals[1] + step, step)
 25        search_range = [round(e, 3) for e in search_range]
 26    else:
 27        search_range = [input_vals]
 28    return search_range
 29
 30
 31def default_grid_search_values_amg(
 32    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
 33) -> Dict[str, List[float]]:
 34    """Default grid-search parameter for AMG-based instance segmentation.
 35
 36    Return grid search values for the two most important parameters:
 37    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
 38    - `stability_score_thresh`, the theshold for keepong objects according to their stability.
 39
 40    Args:
 41        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
 42            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 43        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
 44            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 45
 46    Returns:
 47        The values for grid search.
 48    """
 49    if iou_thresh_values is None:
 50        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
 51    if stability_score_values is None:
 52        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
 53    return {
 54        "pred_iou_thresh": iou_thresh_values,
 55        "stability_score_thresh": stability_score_values,
 56    }
 57
 58
 59def default_grid_search_values_instance_segmentation_with_decoder(
 60    center_distance_threshold_values: Optional[List[float]] = None,
 61    boundary_distance_threshold_values: Optional[List[float]] = None,
 62    distance_smoothing_values: Optional[List[float]] = None,
 63    min_size_values: Optional[List[float]] = None,
 64) -> Dict[str, List[float]]:
 65    """Default grid-search parameter for decoder-based instance segmentation.
 66
 67    Args:
 68        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 69            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 70        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 71            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 72        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 73            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
 74        min_size_values: The values for `min_size` used in the gridsearch.
 75            By default the values 50, 100 and 200  are used.
 76
 77    Returns:
 78        The values for grid search.
 79    """
 80    if center_distance_threshold_values is None:
 81        center_distance_threshold_values = _get_range_of_search_values(
 82            [0.3, 0.7], step=0.1
 83        )
 84    if boundary_distance_threshold_values is None:
 85        boundary_distance_threshold_values = _get_range_of_search_values(
 86            [0.3, 0.7], step=0.1
 87        )
 88    if distance_smoothing_values is None:
 89        distance_smoothing_values = _get_range_of_search_values(
 90            [1.0, 2.0], step=0.2
 91        )
 92    if min_size_values is None:
 93        min_size_values = [50, 100, 200]
 94
 95    return {
 96        "center_distance_threshold": center_distance_threshold_values,
 97        "boundary_distance_threshold": boundary_distance_threshold_values,
 98        "distance_smoothing": distance_smoothing_values,
 99        "min_size": min_size_values,
100    }
101
102
def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Score every grid-search parameter combination on a single image.

    The segmenter is expected to be initialized for the current image already; this function
    only calls `segmenter.generate` once per combination, converts the masks to an instance
    segmentation, and compares it to `gt` via `mean_segmentation_accuracy`.

    Args:
        segmenter: The initialized instance segmentation class.
        gs_combinations: One kwargs dict per parameter combination to evaluate.
        gt: The ground-truth instance segmentation for this image.
        image_name: Name stored in the `image_name` column of every row.
        fixed_generate_kwargs: Extra kwargs passed to every `generate` call.
        result_path: Where the per-image CSV with one row per combination is written.
        verbose: Whether to show a progress bar over the combinations.

    Returns:
        The per-combination results (also written to `result_path`).
    """
    rows = []
    for combination in tqdm(gs_combinations, disable=not verbose):
        # On key collisions the right operand of `|` wins, i.e. the fixed kwargs.
        call_kwargs = combination | fixed_generate_kwargs
        mask_data = segmenter.generate(**call_kwargs)

        if len(mask_data) == 0:
            # No objects were found: score an all-background segmentation instead.
            segmentation = np.zeros(gt.shape, dtype="uint32")
        else:
            segmentation = mask_data_to_segmentation(
                mask_data,
                with_background=True,
                min_object_size=call_kwargs.get("min_mask_region_area", 0),
            )
        msa, accuracies = mean_segmentation_accuracy(segmentation, gt, return_accuracies=True)  # type: ignore

        # NOTE(review): indices 0 and 5 are taken to be the 0.5 and 0.75 IoU thresholds,
        # which assumes accuracies are reported for 0.5, 0.55, ..., 0.95 — confirm in elf.
        record = {"image_name": image_name, "mSA": msa, "SA50": accuracies[0], "SA75": accuracies[5]}
        record.update(combination)
        rows.append(pd.DataFrame([record]))

    per_image_df = pd.concat(rows)
    per_image_df.to_csv(result_path, index=False)
    return per_image_df
133
134
def _load_image(path, key, roi):
    """Load image data from `path`, optionally cropped to `roi`.

    Args:
        path: The file to read.
        key: Dataset key inside a container file (e.g. HDF5/zarr, opened via `open_file`).
            If None, the file is read as a plain image with imageio.
        roi: Optional index expression applied to the loaded data.

    Returns:
        The (possibly cropped) data as read from file.
    """
    if key is not None:
        with open_file(path, "r") as f:
            dataset = f[key]
            # Slicing inside the `with` block materializes the data before the file closes.
            return dataset[:] if roi is None else dataset[roi]

    data = imageio.imread(path)
    return data if roi is None else data[roi]
145
146
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    Results are written to one CSV per image in `result_dir`; images whose CSV already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to, one per image
            (indexed in parallel with `image_paths`).
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

    Raises:
        ValueError: If a parameter appears both in `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
    # Resolve the default once instead of on every loop iteration.
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params

    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Compute all combinations of grid search values and map each back to a kwargs dict.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            # Kept as-is (splitext on the stem) so existing embedding caches keep their file names.
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
251
252
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Each prediction is written to `prediction_dir` under the input image's file name;
    images whose prediction file already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. Created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # Resolve the default once instead of on every loop iteration.
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params
    predictor = getattr(segmenter, "_predictor", None)
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # imageio.imwrite does not create missing folders.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
321
322
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Reads every `*.csv` file in `result_dir`, averages `criterion` over all images for each
    parameter combination, and returns the combination with the highest mean.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting. Each value keeps the dtype of its CSV column,
            so integer parameters such as `min_size` stay integers.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` contains no CSV result files.
    """
    param_names = list(grid_search_parameters)

    # Load all the grid search results (sorted for a deterministic order).
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"No grid-search result files (*.csv) were found in {result_dir}.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Average the criterion over all images for each parameter combination.
    grouped_result = gs_result.groupby(param_names)[criterion].mean().reset_index()

    # Find the best score and corresponding parameters.
    best_idx = grouped_result[criterion].idxmax()
    best_score = grouped_result[criterion].max()
    # Look each parameter up by column name, so it keeps its own dtype. Taking a whole row
    # via `iloc` would upcast int parameters to float.
    best_kwargs = {k: grouped_result.at[best_idx, k] for k in param_names}

    return best_kwargs, best_score
354
355
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best grid-search parameters and their score to a CSV file.

    The file name depends on the segmentation method. If either AMG-specific threshold
    (`pred_iou_thresh` or `stability_score_thresh`) was searched, the file is
    `grid_search_params_amg.csv`. Otherwise it is
    `grid_search_params_instance_segmentation_with_decoder.csv`.

    Args:
        best_kwargs: The best parameter setting, keyed by parameter name.
        best_msa: The score achieved with `best_kwargs`.
        grid_search_result_dir: If given, the file is written to `<dir>/results/`,
            which is created if needed. Otherwise it is written to the current working directory.
    """
    # One row: `best_msa` first, followed by the parameter columns.
    param_df = pd.DataFrame([best_kwargs])
    res_df = pd.DataFrame([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # The previous expression `"a" and "b" in d` only tested "b" (operator precedence).
    amg_params = ("pred_iou_thresh", "stability_score_thresh")
    if any(param in best_kwargs for param in amg_params):
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
372
373
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    The best parameters found on the validation images are combined with
    `fixed_generate_kwargs` and used for inference on the test images.
    The caller's `fixed_generate_kwargs` dict is not modified.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Copy before merging so the caller's dict is not mutated.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )
def default_grid_search_values_amg( iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None) -> Dict[str, List[float]]:
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameter for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by the `generate` parameter name.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        # Note: the stability range intentionally extends to 0.95 (higher than the IoU range).
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }

Default grid-search parameter for AMG-based instance segmentation.

Return grid search values for the two most important parameters:

  • pred_iou_thresh, the threshold for keeping objects according to the IoU predicted by the model.
  • stability_score_thresh, the threshold for keeping objects according to their stability.
Arguments:
  • iou_thresh_values: The values for pred_iou_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
  • stability_score_values: The values for stability_score_thresh used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:

The values for grid search.

def default_grid_search_values_instance_segmentation_with_decoder( center_distance_threshold_values: Optional[List[float]] = None, boundary_distance_threshold_values: Optional[List[float]] = None, distance_smoothing_values: Optional[List[float]] = None, min_size_values: Optional[List[float]] = None) -> Dict[str, List[float]]:
 60def default_grid_search_values_instance_segmentation_with_decoder(
 61    center_distance_threshold_values: Optional[List[float]] = None,
 62    boundary_distance_threshold_values: Optional[List[float]] = None,
 63    distance_smoothing_values: Optional[List[float]] = None,
 64    min_size_values: Optional[List[float]] = None,
 65) -> Dict[str, List[float]]:
 66    """Default grid-search parameter for decoder-based instance segmentation.
 67
 68    Args:
 69        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 70            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 71        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 72            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 73        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 74            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
 75        min_size_values: The values for `min_size` used in the gridsearch.
 76            By default the values 50, 100 and 200  are used.
 77
 78    Returns:
 79        The values for grid search.
 80    """
 81    if center_distance_threshold_values is None:
 82        center_distance_threshold_values = _get_range_of_search_values(
 83            [0.3, 0.7], step=0.1
 84        )
 85    if boundary_distance_threshold_values is None:
 86        boundary_distance_threshold_values = _get_range_of_search_values(
 87            [0.3, 0.7], step=0.1
 88        )
 89    if distance_smoothing_values is None:
 90        distance_smoothing_values = _get_range_of_search_values(
 91            [1.0, 2.0], step=0.2
 92        )
 93    if min_size_values is None:
 94        min_size_values = [50, 100, 200]
 95
 96    return {
 97        "center_distance_threshold": center_distance_threshold_values,
 98        "boundary_distance_threshold": boundary_distance_threshold_values,
 99        "distance_smoothing": distance_smoothing_values,
100        "min_size": min_size_values,
101    }

Default grid-search parameter for decoder-based instance segmentation.

Arguments:
  • center_distance_threshold_values: The values for center_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • boundary_distance_threshold_values: The values for boundary_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • distance_smoothing_values: The values for distance_smoothing used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
  • min_size_values: The values for min_size used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:

The values for grid search.

def run_instance_segmentation_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike, NoneType], prediction_dir: Union[str, os.PathLike], generate_kwargs: Optional[Dict[str, Any]] = None, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> None:
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run inference for automatic mask generation.

    Each prediction is written to `prediction_dir` under the input image's file name;
    images whose prediction file already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. Created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    # Resolve the default once instead of on every loop iteration.
    tiling_window_params = {} if tiling_window_params is None else tiling_window_params
    predictor = getattr(segmenter, "_predictor", None)
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # imageio.imwrite does not create missing folders.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
        )

        segmenter.initialize(image, image_embeddings, **tiling_window_params)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

Run inference for automatic mask generation.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • image_paths: The input images.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • generate_kwargs: The keyword arguments for the generate method of the segmenter.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for automatic segmentation.
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
    """Save the best grid-search parameters and their score to a CSV file.

    The file name depends on the segmentation method. If either AMG-specific threshold
    (`pred_iou_thresh` or `stability_score_thresh`) was searched, the file is
    `grid_search_params_amg.csv`. Otherwise it is
    `grid_search_params_instance_segmentation_with_decoder.csv`.

    Args:
        best_kwargs: The best parameter setting, keyed by parameter name.
        best_msa: The score achieved with `best_kwargs`.
        grid_search_result_dir: If given, the file is written to `<dir>/results/`,
            which is created if needed. Otherwise it is written to the current working directory.
    """
    # One row: `best_msa` first, followed by the parameter columns.
    param_df = pd.DataFrame([best_kwargs])
    res_df = pd.DataFrame([{"best_msa": best_msa}])
    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)

    # The previous expression `"a" and "b" in d` only tested "b" (operator precedence).
    amg_params = ("pred_iou_thresh", "stability_score_thresh")
    if any(param in best_kwargs for param in amg_params):
        path_name = "grid_search_params_amg.csv"
    else:
        path_name = "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], grid_search_values: Dict[str, List], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike, NoneType], prediction_dir: Union[str, os.PathLike], experiment_folder: Union[str, os.PathLike], result_dir: Union[str, os.PathLike], fixed_generate_kwargs: Optional[Dict[str, Any]] = None, verbose_gs: bool = True, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> None:
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    The best parameters found on the validation images are combined with
    `fixed_generate_kwargs` and used for inference on the test images.
    The caller's `fixed_generate_kwargs` dict is not modified.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        tiling_window_params: The parameters to decide whether to use tiling window operation
            for automatic segmentation.
    """
    run_instance_segmentation_grid_search(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        image_paths=val_image_paths,
        gt_paths=val_gt_paths,
        result_dir=result_dir,
        embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs,
        verbose_gs=verbose_gs,
        tiling_window_params=tiling_window_params,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Copy before merging so the caller's dict is not mutated.
    generate_kwargs = {} if fixed_generate_kwargs is None else dict(fixed_generate_kwargs)
    generate_kwargs.update(best_kwargs)

    run_instance_segmentation_inference(
        segmenter=segmenter,
        image_paths=test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=prediction_dir,
        generate_kwargs=generate_kwargs,
        tiling_window_params=tiling_window_params,
    )

Run grid search and inference for automatic mask generation.

Please refer to the documentation of run_instance_segmentation_grid_search for details on how to specify the grid search parameters.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • grid_search_values: The grid search values for parameters of the generate function.
  • val_image_paths: The input images for the grid search.
  • val_gt_paths: The ground-truth segmentation for the grid search.
  • test_image_paths: The input images for inference.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • experiment_folder: Folder for caching best grid search parameters in 'results'.
  • result_dir: Folder to cache the evaluation results per image.
  • fixed_generate_kwargs: Fixed keyword arguments for the generate method of the segmenter.
  • verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for automatic segmentation.