micro_sam.evaluation.instance_segmentation

Inference and evaluation for the automatic instance segmentation functionality.

  1"""Inference and evaluation for the automatic instance segmentation functionality.
  2"""
  3
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from itertools import product
  9from typing import Any, Dict, List, Optional, Tuple, Union
 10
 11import numpy as np
 12import pandas as pd
 13import imageio.v3 as imageio
 14
 15from elf.io import open_file
 16from elf.evaluation import mean_segmentation_accuracy
 17
 18from .. import util
 19from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
 20
 21
 22def _get_range_of_search_values(input_vals, step):
 23    if isinstance(input_vals, list):
 24        search_range = np.arange(input_vals[0], input_vals[1] + step, step)
 25        search_range = [round(e, 3) for e in search_range]
 26    else:
 27        search_range = [input_vals]
 28    return search_range
 29
 30
 31def default_grid_search_values_amg(
 32    iou_thresh_values: Optional[List[float]] = None,
 33    stability_score_values: Optional[List[float]] = None,
 34) -> Dict[str, List[float]]:
 35    """Default grid-search parameter for AMG-based instance segmentation.
 36
 37    Return grid search values for the two most important parameters:
 38    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
 39    - `stability_score_thresh`, the theshold for keepong objects according to their stability.
 40
 41    Args:
 42        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
 43            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 44        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
 45            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 46
 47    Returns:
 48        The values for grid search.
 49    """
 50    if iou_thresh_values is None:
 51        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
 52    if stability_score_values is None:
 53        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
 54    return {
 55        "pred_iou_thresh": iou_thresh_values,
 56        "stability_score_thresh": stability_score_values,
 57    }
 58
 59
 60def default_grid_search_values_instance_segmentation_with_decoder(
 61    center_distance_threshold_values: Optional[List[float]] = None,
 62    boundary_distance_threshold_values: Optional[List[float]] = None,
 63    distance_smoothing_values: Optional[List[float]] = None,
 64    min_size_values: Optional[List[float]] = None,
 65) -> Dict[str, List[float]]:
 66    """Default grid-search parameter for decoder-based instance segmentation.
 67
 68    Args:
 69        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 70            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 71        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 72            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 73        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 74            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
 75        min_size_values: The values for `min_size` used in the gridsearch.
 76            By default the values 50, 100 and 200  are used.
 77
 78    Returns:
 79        The values for grid search.
 80    """
 81    if center_distance_threshold_values is None:
 82        center_distance_threshold_values = _get_range_of_search_values(
 83            [0.3, 0.7], step=0.1
 84        )
 85    if boundary_distance_threshold_values is None:
 86        boundary_distance_threshold_values = _get_range_of_search_values(
 87            [0.3, 0.7], step=0.1
 88        )
 89    if distance_smoothing_values is None:
 90        distance_smoothing_values = _get_range_of_search_values(
 91            [1.0, 2.0], step=0.2
 92        )
 93    if min_size_values is None:
 94        min_size_values = [50, 100, 200]
 95    return {
 96        "center_distance_threshold": center_distance_threshold_values,
 97        "boundary_distance_threshold": boundary_distance_threshold_values,
 98        "distance_smoothing": distance_smoothing_values,
 99        "min_size": min_size_values,
100    }
101
102
def _grid_search_iteration(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    gs_combinations: List[Dict],
    gt: np.ndarray,
    image_name: str,
    fixed_generate_kwargs: Dict[str, Any],
    result_path: Optional[Union[str, os.PathLike]],
    verbose: bool = False,
) -> pd.DataFrame:
    """Score every grid-search setting on a single image and write the scores to csv.

    Args:
        segmenter: The instance segmentation class. The caller runs `initialize` for the image first.
        gs_combinations: One dict of `generate` keyword arguments per grid-search setting.
        gt: The ground-truth instance segmentation of the image.
        image_name: Value stored in the `image_name` column.
        fixed_generate_kwargs: Keyword arguments passed to `generate` for every setting.
            On a key clash they take precedence over the grid-search values (`|` keeps the right operand).
        result_path: Where the table is written as csv, without the index.
        verbose: Whether to show a progress bar over the settings.

    Returns:
        One row per setting with `image_name`, `mSA`, `SA50`, `SA75` and the grid-search parameters.
    """
    per_setting_frames = []
    for setting in tqdm(gs_combinations, disable=not verbose):
        call_kwargs = setting | fixed_generate_kwargs
        masks = segmenter.generate(**call_kwargs)

        if len(masks) > 0:
            segmentation = mask_data_to_segmentation(
                masks, with_background=True, min_object_size=call_kwargs.get("min_mask_region_area", 0)
            )
        else:
            # Nothing was segmented: score an all-background label image with the ground-truth's shape.
            segmentation = np.zeros(gt.shape, dtype="uint32")

        msa, sa_per_threshold = mean_segmentation_accuracy(segmentation, gt, return_accuracies=True)  # type: ignore
        # Entry 0 of the per-threshold accuracies is reported as SA50 and entry 5 as SA75.
        scores = {"image_name": image_name, "mSA": msa, "SA50": sa_per_threshold[0], "SA75": sa_per_threshold[5]}
        per_setting_frames.append(pd.DataFrame([{**scores, **setting}]))

    results = pd.concat(per_setting_frames)
    results.to_csv(result_path, index=False)
    return results
133
134
def _load_image(path, key, roi):
    """Load an image, optionally cropped to `roi`.

    Without a `key` the file is read with `imageio`. With a `key`, the dataset of that name is read
    through `elf.io.open_file` (e.g. from HDF5/zarr), and only the `roi` part is loaded when given.
    """
    if key is not None:
        with open_file(path, "r") as f:
            dataset = f[key]
            return dataset[:] if roi is None else dataset[roi]

    image = imageio.imread(path)
    return image if roi is None else image[roi]
144
145
def run_instance_segmentation_grid_search(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    result_dir: Union[str, os.PathLike],
    embedding_dir: Optional[Union[str, os.PathLike]],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = False,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
) -> None:
    """Run grid search for automatic mask generation.

    The parameters and their respective value ranges for the grid search are specified via the
    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
    and 'stability_score_thresh', you can pass the following:
    ```
    grid_search_values = {
        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
    }
    ```
    All combinations of the parameters will be checked.

    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
    or `default_grid_search_values_amg` to get the default grid search parameters for the two
    respective instance segmentation methods.

    The results for each image are written to `<result_dir>/<image stem>.csv`. Images whose result
    file already exists are skipped, so an interrupted grid search can be resumed.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        image_paths: The input images for the grid search.
        gt_paths: The ground-truth segmentation for the grid search.
        result_dir: Folder to cache the evaluation results per image.
        embedding_dir: Folder to cache the image embeddings.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
        image_key: Key for loading the image data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
            If not given a simple image format like tif is assumed.
        rois: Regions of interest to restrict the evaluation to, one per image (indexed like `image_paths`).

    Raises:
        ValueError: If a parameter appears in both `grid_search_values` and `fixed_generate_kwargs`.
    """
    verbose_embeddings = False

    assert len(image_paths) == len(gt_paths)
    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs

    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
    if duplicate_params:
        raise ValueError(
            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
            f"The parameters {duplicate_params} are duplicated."
        )

    # Expand the grid into all combinations, each mapped back to a dict of `generate` kwargs.
    param_names = list(grid_search_values.keys())
    gs_combinations = [dict(zip(param_names, vals)) for vals in product(*grid_search_values.values())]

    os.makedirs(result_dir, exist_ok=True)
    predictor = getattr(segmenter, "_predictor", None)

    for i, (image_path, gt_path) in tqdm(
        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
    ):
        image_name = Path(image_path).stem
        result_path = os.path.join(result_dir, f"{image_name}.csv")

        # We skip images for which the grid search was done already.
        if os.path.exists(result_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        roi = None if rois is None else rois[i]
        image = _load_image(image_path, image_key, roi=roi)
        gt = _load_image(gt_path, gt_key, roi=roi)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            # `image_name` is already the stem. Stripping a further suffix would map e.g. "a.1.tif" and
            # "a.2.tif" to the same cache file, and would differ from the naming used at inference time.
            embedding_path = os.path.join(embedding_dir, f"{image_name}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        _grid_search_iteration(
            segmenter, gs_combinations, gt, image_name,
            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
        )
245
246
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Run inference for automatic mask generation.

    The segmentation of each image is written to `prediction_dir` under the image's file name.
    Images whose prediction file already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
    """

    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    predictor = segmenter._predictor
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Writing a prediction fails if its folder does not exist, so create it up front.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)
309
310
def evaluate_instance_segmentation_grid_search(
    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
) -> Tuple[Dict[str, Any], float]:
    """Evaluate gridsearch results.

    Reads all per-image csv files in `result_dir` and averages `criterion` over the images for every
    parameter setting. Then picks the setting with the highest average.

    Args:
        result_dir: The folder with the gridsearch results.
        grid_search_parameters: The names for the gridsearch parameters.
        criterion: The metric to use for determining the best parameters.

    Returns:
        The best parameter setting.
        The evaluation score for the best setting.

    Raises:
        ValueError: If `result_dir` does not contain any csv files.
    """
    # Load all the grid search results.
    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"No grid search results (*.csv) were found in {result_dir}.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    # Keep only the parameter columns and the criterion, then average the criterion per parameter setting.
    grouped_result = (
        gs_result[grid_search_parameters + [criterion]].groupby(grid_search_parameters).mean().reset_index()
    )

    # Find the best score and corresponding parameters.
    best_idx = grouped_result[criterion].idxmax()
    best_score = grouped_result.at[best_idx, criterion]
    # Read each parameter from its own column: selecting the whole row would upcast integer parameters
    # (e.g. `min_size`) to float.
    best_kwargs = {name: grouped_result.at[best_idx, name] for name in grid_search_parameters}

    return best_kwargs, best_score
342
343
def save_grid_search_best_params(
    best_kwargs: Dict[str, Any],
    best_msa: float,
    grid_search_result_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Save the best grid-search parameters together with their score to a csv file.

    The file is named `grid_search_params_amg.csv` if `best_kwargs` contains an AMG parameter
    (`pred_iou_thresh` or `stability_score_thresh`). Otherwise it is named
    `grid_search_params_instance_segmentation_with_decoder.csv`.

    Args:
        best_kwargs: The best parameter setting.
        best_msa: The score of the best parameter setting, stored in the `best_msa` column.
        grid_search_result_dir: If given, the file is written to its `results` subfolder, which is created
            if needed. Otherwise the file name is used as a relative path.
    """
    # A single row: the score first, followed by the parameter values.
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # Test each key explicitly: `"a" and "b" in d` parses as `"a" and ("b" in d)`, i.e. only checks "b".
    amg_params = ("pred_iou_thresh", "stability_score_thresh")
    is_amg = any(param in best_kwargs for param in amg_params)
    path_name = "grid_search_params_amg.csv" if is_amg \
        else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
360
361
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            The dict is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
    """
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, val_image_paths, val_gt_paths,
        result_dir=result_dir, embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Build a new dict instead of updating `fixed_generate_kwargs` in place. Mutating it would leak the
    # best parameters into the caller's dict, and reusing that dict with the same `grid_search_values`
    # would then fail the duplicate-parameter check of the grid search.
    generate_kwargs = {**(fixed_generate_kwargs or {}), **best_kwargs}

    run_instance_segmentation_inference(
        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
    )
def default_grid_search_values_amg( iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None) -> Dict[str, List[float]]:
def default_grid_search_values_amg(
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
) -> Dict[str, List[float]]:
    """Default grid-search parameters for AMG-based instance segmentation.

    Return grid search values for the two most important parameters:
    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
    - `stability_score_thresh`, the threshold for keeping objects according to their stability.

    Args:
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.

    Returns:
        The values for grid search, keyed by `pred_iou_thresh` and `stability_score_thresh`.
    """
    if iou_thresh_values is None:
        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
    if stability_score_values is None:
        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
    return {
        "pred_iou_thresh": iou_thresh_values,
        "stability_score_thresh": stability_score_values,
    }

Default grid-search parameter for AMG-based instance segmentation.

Return grid search values for the two most important parameters:

  • pred_iou_thresh, the threshold for keeping objects according to the IoU predicted by the model.
  • stability_score_thresh, the threshold for keeping objects according to their stability.
Arguments:
  • iou_thresh_values: The values for pred_iou_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
  • stability_score_values: The values for stability_score_thresh used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:

The values for grid search.

def default_grid_search_values_instance_segmentation_with_decoder( center_distance_threshold_values: Optional[List[float]] = None, boundary_distance_threshold_values: Optional[List[float]] = None, distance_smoothing_values: Optional[List[float]] = None, min_size_values: Optional[List[float]] = None) -> Dict[str, List[float]]:
 61def default_grid_search_values_instance_segmentation_with_decoder(
 62    center_distance_threshold_values: Optional[List[float]] = None,
 63    boundary_distance_threshold_values: Optional[List[float]] = None,
 64    distance_smoothing_values: Optional[List[float]] = None,
 65    min_size_values: Optional[List[float]] = None,
 66) -> Dict[str, List[float]]:
 67    """Default grid-search parameter for decoder-based instance segmentation.
 68
 69    Args:
 70        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 71            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 72        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 73            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 74        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 75            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
 76        min_size_values: The values for `min_size` used in the gridsearch.
 77            By default the values 50, 100 and 200  are used.
 78
 79    Returns:
 80        The values for grid search.
 81    """
 82    if center_distance_threshold_values is None:
 83        center_distance_threshold_values = _get_range_of_search_values(
 84            [0.3, 0.7], step=0.1
 85        )
 86    if boundary_distance_threshold_values is None:
 87        boundary_distance_threshold_values = _get_range_of_search_values(
 88            [0.3, 0.7], step=0.1
 89        )
 90    if distance_smoothing_values is None:
 91        distance_smoothing_values = _get_range_of_search_values(
 92            [1.0, 2.0], step=0.2
 93        )
 94    if min_size_values is None:
 95        min_size_values = [50, 100, 200]
 96    return {
 97        "center_distance_threshold": center_distance_threshold_values,
 98        "boundary_distance_threshold": boundary_distance_threshold_values,
 99        "distance_smoothing": distance_smoothing_values,
100        "min_size": min_size_values,
101    }

Default grid-search parameter for decoder-based instance segmentation.

Arguments:
  • center_distance_threshold_values: The values for center_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • boundary_distance_threshold_values: The values for boundary_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • distance_smoothing_values: The values for distance_smoothing used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
  • min_size_values: The values for min_size used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:

The values for grid search.

def run_instance_segmentation_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike, NoneType], prediction_dir: Union[str, os.PathLike], generate_kwargs: Optional[Dict[str, Any]] = None) -> None:
def run_instance_segmentation_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    generate_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Run inference for automatic mask generation.

    The segmentation of each image is written to `prediction_dir` under the image's file name.
    Images whose prediction file already exists are skipped.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        image_paths: The input images.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions. It is created if it does not exist.
        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
    """

    verbose_embeddings = False

    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
    predictor = segmenter._predictor
    min_object_size = generate_kwargs.get("min_mask_region_area", 0)

    # Writing a prediction fails if its folder does not exist, so create it up front.
    os.makedirs(prediction_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        image = imageio.imread(image_path)

        if embedding_dir is None:
            embedding_path = None
        else:
            assert predictor is not None
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
        )

        segmenter.initialize(image, image_embeddings)

        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

Run inference for automatic mask generation.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • image_paths: The input images.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • generate_kwargs: The keyword arguments for the generate method of the segmenter.
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
def save_grid_search_best_params(
    best_kwargs: Dict[str, Any],
    best_msa: float,
    grid_search_result_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Save the best grid-search parameters together with their score to a csv file.

    The file is named `grid_search_params_amg.csv` if `best_kwargs` contains an AMG parameter
    (`pred_iou_thresh` or `stability_score_thresh`). Otherwise it is named
    `grid_search_params_instance_segmentation_with_decoder.csv`.

    Args:
        best_kwargs: The best parameter setting.
        best_msa: The score of the best parameter setting, stored in the `best_msa` column.
        grid_search_result_dir: If given, the file is written to its `results` subfolder, which is created
            if needed. Otherwise the file name is used as a relative path.
    """
    # A single row: the score first, followed by the parameter values.
    best_param_df = pd.DataFrame([{"best_msa": best_msa, **best_kwargs}])

    # Test each key explicitly: `"a" and "b" in d` parses as `"a" and ("b" in d)`, i.e. only checks "b".
    amg_params = ("pred_iou_thresh", "stability_score_thresh")
    is_amg = any(param in best_kwargs for param in amg_params)
    path_name = "grid_search_params_amg.csv" if is_amg \
        else "grid_search_params_instance_segmentation_with_decoder.csv"

    if grid_search_result_dir is not None:
        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
        res_path = os.path.join(grid_search_result_dir, "results", path_name)
    else:
        res_path = path_name

    best_param_df.to_csv(res_path)
def run_instance_segmentation_grid_search_and_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], grid_search_values: Dict[str, List], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike, NoneType], prediction_dir: Union[str, os.PathLike], experiment_folder: Union[str, os.PathLike], result_dir: Union[str, os.PathLike], fixed_generate_kwargs: Optional[Dict[str, Any]] = None, verbose_gs: bool = True) -> None:
def run_instance_segmentation_grid_search_and_inference(
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    grid_search_values: Dict[str, List],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    experiment_folder: Union[str, os.PathLike],
    result_dir: Union[str, os.PathLike],
    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
    verbose_gs: bool = True,
) -> None:
    """Run grid search and inference for automatic mask generation.

    Please refer to the documentation of `run_instance_segmentation_grid_search`
    for details on how to specify the grid search parameters.

    Args:
        segmenter: The class implementing the instance segmentation functionality.
        grid_search_values: The grid search values for parameters of the `generate` function.
        val_image_paths: The input images for the grid search.
        val_gt_paths: The ground-truth segmentation for the grid search.
        test_image_paths: The input images for inference.
        embedding_dir: Folder to cache the image embeddings.
        prediction_dir: Folder to save the predictions.
        experiment_folder: Folder for caching best grid search parameters in 'results'.
        result_dir: Folder to cache the evaluation results per image.
        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
            The dict is not modified.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
    """
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, val_image_paths, val_gt_paths,
        result_dir=result_dir, embedding_dir=embedding_dir,
        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
    )

    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
    print("Best grid-search result:", best_msa, "with parameters:\n", best_param_str)
    print()

    save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

    # Build a new dict instead of updating `fixed_generate_kwargs` in place. Mutating it would leak the
    # best parameters into the caller's dict, and reusing that dict with the same `grid_search_values`
    # would then fail the duplicate-parameter check of the grid search.
    generate_kwargs = {**(fixed_generate_kwargs or {}), **best_kwargs}

    run_instance_segmentation_inference(
        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
    )

Run grid search and inference for automatic mask generation.

Please refer to the documentation of run_instance_segmentation_grid_search for details on how to specify the grid search parameters.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • grid_search_values: The grid search values for parameters of the generate function.
  • val_image_paths: The input images for the grid search.
  • val_gt_paths: The ground-truth segmentation for the grid search.
  • test_image_paths: The input images for inference.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • experiment_folder: Folder for caching best grid search parameters in 'results'.
  • result_dir: Folder to cache the evaluation results per image.
  • fixed_generate_kwargs: Fixed keyword arguments for the generate method of the segmenter.
  • verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.