micro_sam.evaluation.instance_segmentation

Inference and evaluation for the automatic instance segmentation functionality.

  1"""Inference and evaluation for the automatic instance segmentation functionality.
  2"""
  3
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from itertools import product
  9from typing import Any, Dict, List, Optional, Tuple, Union
 10
 11import numpy as np
 12import pandas as pd
 13import imageio.v3 as imageio
 14
 15from elf.io import open_file
 16from elf.evaluation import mean_segmentation_accuracy
 17
 18from .. import util
 19from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
 20
 21
 22def _get_range_of_search_values(input_vals, step):
 23    if isinstance(input_vals, list):
 24        search_range = np.arange(input_vals[0], input_vals[1] + step, step)
 25        search_range = [round(e, 3) for e in search_range]
 26    else:
 27        search_range = [input_vals]
 28    return search_range
 29
 30
 31def default_grid_search_values_amg(
 32    iou_thresh_values: Optional[List[float]] = None,
 33    stability_score_values: Optional[List[float]] = None,
 34) -> Dict[str, List[float]]:
 35    """Default grid-search parameter for AMG-based instance segmentation.
 36
 37    Return grid search values for the two most important parameters:
 38    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
 39    - `stability_score_thresh`, the threshold for keeping objects according to their stability.
 40
 41    Args:
 42        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
 43            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
 44        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
 45            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
 46
 47    Returns:
 48        The values for grid search.
 49    """
 50    if iou_thresh_values is None:
 51        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
 52    if stability_score_values is None:
 53        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
 54    return {
 55        "pred_iou_thresh": iou_thresh_values,
 56        "stability_score_thresh": stability_score_values,
 57    }
 58
 59
 60def default_grid_search_values_instance_segmentation_with_decoder(
 61    center_distance_threshold_values: Optional[List[float]] = None,
 62    boundary_distance_threshold_values: Optional[List[float]] = None,
 63    distance_smoothing_values: Optional[List[float]] = None,
 64    min_size_values: Optional[List[float]] = None,
 65) -> Dict[str, List[float]]:
 66    """Default grid-search parameter for decoder-based instance segmentation.
 67
 68    Args:
 69        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 70            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 71        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 72            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 73        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 74            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
 75        min_size_values: The values for `min_size` used in the gridsearch.
 76            By default the values 50, 100 and 200 are used.
 77
 78    Returns:
 79        The values for grid search.
 80    """
 81    if center_distance_threshold_values is None:
 82        center_distance_threshold_values = _get_range_of_search_values(
 83            [0.3, 0.7], step=0.1
 84        )
 85    if boundary_distance_threshold_values is None:
 86        boundary_distance_threshold_values = _get_range_of_search_values(
 87            [0.3, 0.7], step=0.1
 88        )
 89    if distance_smoothing_values is None:
 90        distance_smoothing_values = _get_range_of_search_values(
 91            [1.0, 2.0], step=0.2
 92        )
 93    if min_size_values is None:
 94        min_size_values = [50, 100, 200]
 95    return {
 96        "center_distance_threshold": center_distance_threshold_values,
 97        "boundary_distance_threshold": boundary_distance_threshold_values,
 98        "distance_smoothing": distance_smoothing_values,
 99        "min_size": min_size_values,
100    }
101
102
103def _grid_search_iteration(
104    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
105    gs_combinations: List[Dict],
106    gt: np.ndarray,
107    image_name: str,
108    fixed_generate_kwargs: Dict[str, Any],
109    result_path: Optional[Union[str, os.PathLike]],
110    verbose: bool = False,
111) -> pd.DataFrame:
112    net_list = []
113    for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
114        generate_kwargs = gs_kwargs | fixed_generate_kwargs
115        masks = segmenter.generate(**generate_kwargs)
116
117        min_object_size = generate_kwargs.get("min_mask_region_area", 0)
118        if len(masks) == 0:
119            instance_labels = np.zeros(gt.shape, dtype="uint32")
120        else:
121            instance_labels = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
122        m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True)  # type: ignore
123
124        result_dict = {"image_name": image_name, "mSA": m_sas, "SA50": sas[0], "SA75": sas[5]}
125        result_dict.update(gs_kwargs)
126        tmp_df = pd.DataFrame([result_dict])
127        net_list.append(tmp_df)
128
129    img_gs_df = pd.concat(net_list)
130    img_gs_df.to_csv(result_path, index=False)
131
132    return img_gs_df
133
134
135def _load_image(path, key, roi):
136    if key is None:
137        im = imageio.imread(path)
138        if roi is not None:
139            im = im[roi]
140        return im
141    with open_file(path, "r") as f:
142        im = f[key][:] if roi is None else f[key][roi]
143    return im
144
145
146def run_instance_segmentation_grid_search(
147    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
148    grid_search_values: Dict[str, List],
149    image_paths: List[Union[str, os.PathLike]],
150    gt_paths: List[Union[str, os.PathLike]],
151    result_dir: Union[str, os.PathLike],
152    embedding_dir: Optional[Union[str, os.PathLike]],
153    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
154    verbose_gs: bool = False,
155    image_key: Optional[str] = None,
156    gt_key: Optional[str] = None,
157    rois: Optional[Tuple[slice, ...]] = None,
158) -> None:
159    """Run grid search for automatic mask generation.
160
161    The parameters and their respective value ranges for the grid search are specified via the
162    'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh'
163    and 'stability_score_thresh', you can pass the following:
164    ```
165    grid_search_values = {
166        "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9],
167        "stability_score_thresh": [0.6, 0.7, 0.8, 0.9],
168    }
169    ```
170    All combinations of the parameters will be checked.
171
172    You can use the functions `default_grid_search_values_instance_segmentation_with_decoder`
173    or `default_grid_search_values_amg` to get the default grid search parameters for the two
174    respective instance segmentation methods.
175
176    Args:
177        segmenter: The class implementing the instance segmentation functionality.
178        grid_search_values: The grid search values for parameters of the `generate` function.
179        image_paths: The input images for the grid search.
180        gt_paths: The ground-truth segmentation for the grid search.
181        result_dir: Folder to cache the evaluation results per image.
182        embedding_dir: Folder to cache the image embeddings.
183        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
184        verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
185        image_key: Key for loading the image data from a more complex file format like HDF5.
186            If not given a simple image format like tif is assumed.
187        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
188            If not given a simple image format like tif is assumed.
189        rois: Regions of interest to restrict the evaluation to.
190    """
191    verbose_embeddings = False
192
193    assert len(image_paths) == len(gt_paths)
194    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
195
196    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
197    if duplicate_params:
198        raise ValueError(
199            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'."
200            f"The parameters {duplicate_params} are duplicated."
201        )
202
203    # Compute all combinations of grid search values.
204    gs_combinations = product(*grid_search_values.values())
205    # Map each combination back to a valid kwarg input.
206    gs_combinations = [
207        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
208    ]
209
210    os.makedirs(result_dir, exist_ok=True)
211    predictor = getattr(segmenter, "_predictor", None)
212
213    for i, (image_path, gt_path) in tqdm(
214        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
215    ):
216        image_name = Path(image_path).stem
217        result_path = os.path.join(result_dir, f"{image_name}.csv")
218
219        # We skip images for which the grid search was done already.
220        if os.path.exists(result_path):
221            continue
222
223        assert os.path.exists(image_path), image_path
224        assert os.path.exists(gt_path), gt_path
225
226        image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
227        gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])
228
229        if embedding_dir is None:
230            segmenter.initialize(image)
231        else:
232            assert predictor is not None
233            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
234            image_embeddings = util.precompute_image_embeddings(
235                predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
236            )
237            segmenter.initialize(image, image_embeddings)
238
239        _grid_search_iteration(
240            segmenter, gs_combinations, gt, image_name,
241            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
242        )
243
244
245def run_instance_segmentation_inference(
246    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
247    image_paths: List[Union[str, os.PathLike]],
248    embedding_dir: Union[str, os.PathLike],
249    prediction_dir: Union[str, os.PathLike],
250    generate_kwargs: Optional[Dict[str, Any]] = None,
251) -> None:
252    """Run inference for automatic mask generation.
253
254    Args:
255        segmenter: The class implementing the instance segmentation functionality.
256        image_paths: The input images.
257        embedding_dir: Folder to cache the image embeddings.
258        prediction_dir: Folder to save the predictions.
259        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
260    """
261
262    verbose_embeddings = False
263
264    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
265    predictor = segmenter._predictor
266    min_object_size = generate_kwargs.get("min_mask_region_area", 0)
267
268    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
269        image_name = os.path.basename(image_path)
270
271        # We skip the images that already have been segmented.
272        prediction_path = os.path.join(prediction_dir, image_name)
273        if os.path.exists(prediction_path):
274            continue
275
276        assert os.path.exists(image_path), image_path
277        image = imageio.imread(image_path)
278
279        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
280        image_embeddings = util.precompute_image_embeddings(
281            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
282        )
283
284        segmenter.initialize(image, image_embeddings)
285        masks = segmenter.generate(**generate_kwargs)
286
287        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
288            if isinstance(segmenter, InstanceSegmentationWithDecoder):
289                this_shape = segmenter._foreground.shape
290            elif isinstance(segmenter, AMGBase):
291                this_shape = segmenter._original_size
292            else:
293                this_shape = image.shape[-2:]
294
295            instances = np.zeros(this_shape, dtype="uint32")
296        else:
297            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
298
299        # It's important to compress here, otherwise the predictions would take up a lot of space.
300        imageio.imwrite(prediction_path, instances, compression=5)
301
302
303def evaluate_instance_segmentation_grid_search(
304    result_dir: Union[str, os.PathLike],
305    grid_search_parameters: List[str],
306    criterion: str = "mSA"
307) -> Tuple[Dict[str, Any], float]:
308    """Evaluate gridsearch results.
309
310    Args:
311        result_dir: The folder with the gridsearch results.
312        grid_search_parameters: The names for the gridsearch parameters.
313        criterion: The metric to use for determining the best parameters.
314
315    Returns:
316        The best parameter setting.
317        The evaluation score for the best setting.
318    """
319
320    # Load all the grid search results.
321    gs_files = glob(os.path.join(result_dir, "*.csv"))
322    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])
323
324    # Retrieve only the relevant columns.
325    gs_result = gs_result[grid_search_parameters + [criterion]].reset_index()
326
327    # Group by the grid-search columns and compute the mean score per parameter setting.
328    grouped_result = gs_result.groupby(grid_search_parameters).mean().reset_index()
329
330    # Find the best score and corresponding parameters.
331    best_score, best_idx = grouped_result[criterion].max(), grouped_result[criterion].idxmax()
332    best_params = grouped_result.iloc[best_idx]
333    assert np.isclose(best_params[criterion], best_score)
334    best_kwargs = {k: v for k, v in zip(grid_search_parameters, best_params)}
335
336    return best_kwargs, best_score
337
338
339def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
340    # saving the best parameters estimated from grid-search in the `results` folder
341    param_df = pd.DataFrame.from_dict([best_kwargs])
342    res_df = pd.DataFrame.from_dict([{"best_msa": best_msa}])
343    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)
344
345    path_name = "grid_search_params_amg.csv" if ("pred_iou_thresh" in best_kwargs and
346        "stability_score_thresh" in best_kwargs) else "grid_search_params_instance_segmentation_with_decoder.csv"
347
348    if grid_search_result_dir is not None:
349        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
350        res_path = os.path.join(grid_search_result_dir, "results", path_name)
351    else:
352        res_path = path_name
353
354    best_param_df.to_csv(res_path)
355
356
357def run_instance_segmentation_grid_search_and_inference(
358    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
359    grid_search_values: Dict[str, List],
360    val_image_paths: List[Union[str, os.PathLike]],
361    val_gt_paths: List[Union[str, os.PathLike]],
362    test_image_paths: List[Union[str, os.PathLike]],
363    embedding_dir: Union[str, os.PathLike],
364    prediction_dir: Union[str, os.PathLike],
365    result_dir: Union[str, os.PathLike],
366    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
367    verbose_gs: bool = True,
368) -> None:
369    """Run grid search and inference for automatic mask generation.
370
371    Please refer to the documentation of `run_instance_segmentation_grid_search`
372    for details on how to specify the grid search parameters.
373
374    Args:
375        segmenter: The class implementing the instance segmentation functionality.
376        grid_search_values: The grid search values for parameters of the `generate` function.
377        val_image_paths: The input images for the grid search.
378        val_gt_paths: The ground-truth segmentation for the grid search.
379        test_image_paths: The input images for inference.
380        embedding_dir: Folder to cache the image embeddings.
381        prediction_dir: Folder to save the predictions.
382        result_dir: Folder to cache the evaluation results per image.
383        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
384        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
385    """
386    run_instance_segmentation_grid_search(
387        segmenter, grid_search_values, val_image_paths, val_gt_paths,
388        result_dir=result_dir, embedding_dir=embedding_dir,
389        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
390    )
391
392    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
393    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
394    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
395    print()
396
397    save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent)
398
399    generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
400    generate_kwargs.update(best_kwargs)
401
402    run_instance_segmentation_inference(
403        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
404    )
def default_grid_search_values_amg( iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None) -> Dict[str, List[float]]:
32def default_grid_search_values_amg(
33    iou_thresh_values: Optional[List[float]] = None,
34    stability_score_values: Optional[List[float]] = None,
35) -> Dict[str, List[float]]:
36    """Default grid-search parameter for AMG-based instance segmentation.
37
38    Return grid search values for the two most important parameters:
39    - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model.
40    - `stability_score_thresh`, the threshold for keeping objects according to their stability.
41
42    Args:
43        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
44            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
45        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
46            By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
47
48    Returns:
49        The values for grid search.
50    """
51    if iou_thresh_values is None:
52        iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025)
53    if stability_score_values is None:
54        stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025)
55    return {
56        "pred_iou_thresh": iou_thresh_values,
57        "stability_score_thresh": stability_score_values,
58    }

Default grid-search parameters for AMG-based instance segmentation.

Return grid search values for the two most important parameters:

  • pred_iou_thresh, the threshold for keeping objects according to the IoU predicted by the model.
  • stability_score_thresh, the threshold for keeping objects according to their stability.
Arguments:
  • iou_thresh_values: The values for pred_iou_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
  • stability_score_values: The values for stability_score_thresh used in the gridsearch. By default values in the range from 0.6 to 0.95 with a stepsize of 0.025 will be used.
Returns:

The values for grid search.

def default_grid_search_values_instance_segmentation_with_decoder( center_distance_threshold_values: Optional[List[float]] = None, boundary_distance_threshold_values: Optional[List[float]] = None, distance_smoothing_values: Optional[List[float]] = None, min_size_values: Optional[List[float]] = None) -> Dict[str, List[float]]:
 61def default_grid_search_values_instance_segmentation_with_decoder(
 62    center_distance_threshold_values: Optional[List[float]] = None,
 63    boundary_distance_threshold_values: Optional[List[float]] = None,
 64    distance_smoothing_values: Optional[List[float]] = None,
 65    min_size_values: Optional[List[float]] = None,
 66) -> Dict[str, List[float]]:
 67    """Default grid-search parameter for decoder-based instance segmentation.
 68
 69    Args:
 70        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
 71            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 72        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
 73            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
 74        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
 75            By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
 76        min_size_values: The values for `min_size` used in the gridsearch.
 77            By default the values 50, 100 and 200 are used.
 78
 79    Returns:
 80        The values for grid search.
 81    """
 82    if center_distance_threshold_values is None:
 83        center_distance_threshold_values = _get_range_of_search_values(
 84            [0.3, 0.7], step=0.1
 85        )
 86    if boundary_distance_threshold_values is None:
 87        boundary_distance_threshold_values = _get_range_of_search_values(
 88            [0.3, 0.7], step=0.1
 89        )
 90    if distance_smoothing_values is None:
 91        distance_smoothing_values = _get_range_of_search_values(
 92            [1.0, 2.0], step=0.2
 93        )
 94    if min_size_values is None:
 95        min_size_values = [50, 100, 200]
 96    return {
 97        "center_distance_threshold": center_distance_threshold_values,
 98        "boundary_distance_threshold": boundary_distance_threshold_values,
 99        "distance_smoothing": distance_smoothing_values,
100        "min_size": min_size_values,
101    }

Default grid-search parameters for decoder-based instance segmentation.

Arguments:
  • center_distance_threshold_values: The values for center_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • boundary_distance_threshold_values: The values for boundary_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • distance_smoothing_values: The values for distance_smoothing used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.2 will be used.
  • min_size_values: The values for min_size used in the gridsearch. By default the values 50, 100 and 200 are used.
Returns:

The values for grid search.
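
The same pattern applies to the decoder-based defaults; a brief sketch:

```
from micro_sam.evaluation.instance_segmentation import (
    default_grid_search_values_instance_segmentation_with_decoder,
)

# Default ranges for all four parameters of the decoder-based segmentation.
grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()

# Restrict the search over the object size filter, keep the other defaults.
grid_search_values = default_grid_search_values_instance_segmentation_with_decoder(min_size_values=[100])
```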

def run_instance_segmentation_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], generate_kwargs: Optional[Dict[str, Any]] = None) -> None:
246def run_instance_segmentation_inference(
247    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
248    image_paths: List[Union[str, os.PathLike]],
249    embedding_dir: Union[str, os.PathLike],
250    prediction_dir: Union[str, os.PathLike],
251    generate_kwargs: Optional[Dict[str, Any]] = None,
252) -> None:
253    """Run inference for automatic mask generation.
254
255    Args:
256        segmenter: The class implementing the instance segmentation functionality.
257        image_paths: The input images.
258        embedding_dir: Folder to cache the image embeddings.
259        prediction_dir: Folder to save the predictions.
260        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
261    """
262
263    verbose_embeddings = False
264
265    generate_kwargs = {} if generate_kwargs is None else generate_kwargs
266    predictor = segmenter._predictor
267    min_object_size = generate_kwargs.get("min_mask_region_area", 0)
268
269    for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"):
270        image_name = os.path.basename(image_path)
271
272        # We skip the images that already have been segmented.
273        prediction_path = os.path.join(prediction_dir, image_name)
274        if os.path.exists(prediction_path):
275            continue
276
277        assert os.path.exists(image_path), image_path
278        image = imageio.imread(image_path)
279
280        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
281        image_embeddings = util.precompute_image_embeddings(
282            predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
283        )
284
285        segmenter.initialize(image, image_embeddings)
286        masks = segmenter.generate(**generate_kwargs)
287
288        if len(masks) == 0:  # the instance segmentation can have no masks, hence we just save empty labels
289            if isinstance(segmenter, InstanceSegmentationWithDecoder):
290                this_shape = segmenter._foreground.shape
291            elif isinstance(segmenter, AMGBase):
292                this_shape = segmenter._original_size
293            else:
294                this_shape = image.shape[-2:]
295
296            instances = np.zeros(this_shape, dtype="uint32")
297        else:
298            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
299
300        # It's important to compress here, otherwise the predictions would take up a lot of space.
301        imageio.imwrite(prediction_path, instances, compression=5)

Run inference for automatic mask generation.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • image_paths: The input images.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • generate_kwargs: The keyword arguments for the generate method of the segmenter.
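
A usage sketch under assumptions: the file paths and the segmenter construction (here an AutomaticMaskGenerator built from a SAM predictor via micro_sam.util.get_sam_model) are placeholders and need to be adapted to your setup. Note that the prediction folder is expected to exist before running inference.

```
import os
from glob import glob

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator
from micro_sam.evaluation.instance_segmentation import run_instance_segmentation_inference

# Hypothetical data layout; adjust to your own paths.
test_image_paths = sorted(glob("data/test/images/*.tif"))

# Assumed segmenter setup: a SAM predictor wrapped in an automatic mask generator.
predictor = util.get_sam_model(model_type="vit_b")
segmenter = AutomaticMaskGenerator(predictor)

os.makedirs("predictions", exist_ok=True)  # segmentations are written here as compressed tifs
run_instance_segmentation_inference(
    segmenter,
    image_paths=test_image_paths,
    embedding_dir="embeddings",  # image embeddings are cached here as zarr files
    prediction_dir="predictions",
    generate_kwargs={"pred_iou_thresh": 0.8, "stability_score_thresh": 0.9},
)
```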
def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
340def save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir=None):
341    # saving the best parameters estimated from grid-search in the `results` folder
342    param_df = pd.DataFrame.from_dict([best_kwargs])
343    res_df = pd.DataFrame.from_dict([{"best_msa": best_msa}])
344    best_param_df = pd.merge(res_df, param_df, left_index=True, right_index=True)
345
346    path_name = "grid_search_params_amg.csv" if ("pred_iou_thresh" in best_kwargs and
347        "stability_score_thresh" in best_kwargs) else "grid_search_params_instance_segmentation_with_decoder.csv"
348
349    if grid_search_result_dir is not None:
350        os.makedirs(os.path.join(grid_search_result_dir, "results"), exist_ok=True)
351        res_path = os.path.join(grid_search_result_dir, "results", path_name)
352    else:
353        res_path = path_name
354
355    best_param_df.to_csv(res_path)
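
A short sketch of how this helper is typically used together with evaluate_instance_segmentation_grid_search; the result folder name and parameter list are assumptions.

```
from micro_sam.evaluation.instance_segmentation import (
    evaluate_instance_segmentation_grid_search,
    save_grid_search_best_params,
)

# Aggregate the per-image csv files written by the grid search and pick the best setting.
best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(
    result_dir="grid_search_results",  # hypothetical folder with the per-image results
    grid_search_parameters=["pred_iou_thresh", "stability_score_thresh"],
)

# Writes grid_search_params_amg.csv to <grid_search_result_dir>/results,
# or to the current working directory if no directory is given.
save_grid_search_best_params(best_kwargs, best_msa, grid_search_result_dir="grid_search_results")
```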
def run_instance_segmentation_grid_search_and_inference( segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], grid_search_values: Dict[str, List], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], result_dir: Union[str, os.PathLike], fixed_generate_kwargs: Optional[Dict[str, Any]] = None, verbose_gs: bool = True) -> None:
358def run_instance_segmentation_grid_search_and_inference(
359    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
360    grid_search_values: Dict[str, List],
361    val_image_paths: List[Union[str, os.PathLike]],
362    val_gt_paths: List[Union[str, os.PathLike]],
363    test_image_paths: List[Union[str, os.PathLike]],
364    embedding_dir: Union[str, os.PathLike],
365    prediction_dir: Union[str, os.PathLike],
366    result_dir: Union[str, os.PathLike],
367    fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
368    verbose_gs: bool = True,
369) -> None:
370    """Run grid search and inference for automatic mask generation.
371
372    Please refer to the documentation of `run_instance_segmentation_grid_search`
373    for details on how to specify the grid search parameters.
374
375    Args:
376        segmenter: The class implementing the instance segmentation functionality.
377        grid_search_values: The grid search values for parameters of the `generate` function.
378        val_image_paths: The input images for the grid search.
379        val_gt_paths: The ground-truth segmentation for the grid search.
380        test_image_paths: The input images for inference.
381        embedding_dir: Folder to cache the image embeddings.
382        prediction_dir: Folder to save the predictions.
383        result_dir: Folder to cache the evaluation results per image.
384        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
385        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
386    """
387    run_instance_segmentation_grid_search(
388        segmenter, grid_search_values, val_image_paths, val_gt_paths,
389        result_dir=result_dir, embedding_dir=embedding_dir,
390        fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
391    )
392
393    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
394    best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
395    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
396    print()
397
398    save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent)
399
400    generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
401    generate_kwargs.update(best_kwargs)
402
403    run_instance_segmentation_inference(
404        segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
405    )

Run grid search and inference for automatic mask generation.

Please refer to the documentation of run_instance_segmentation_grid_search for details on how to specify the grid search parameters.

Arguments:
  • segmenter: The class implementing the instance segmentation functionality.
  • grid_search_values: The grid search values for parameters of the generate function.
  • val_image_paths: The input images for the grid search.
  • val_gt_paths: The ground-truth segmentation for the grid search.
  • test_image_paths: The input images for inference.
  • embedding_dir: Folder to cache the image embeddings.
  • prediction_dir: Folder to save the predictions.
  • result_dir: Folder to cache the evaluation results per image.
  • fixed_generate_kwargs: Fixed keyword arguments for the generate method of the segmenter.
  • verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
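
An end-to-end sketch under assumptions: the data layout, the model type and the segmenter construction are placeholders. The grid search runs on the validation split and the best parameter setting found there is then applied to the test images.

```
import os
from glob import glob

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator
from micro_sam.evaluation.instance_segmentation import (
    default_grid_search_values_amg,
    run_instance_segmentation_grid_search_and_inference,
)

# Hypothetical validation / test layout; adjust to your data.
val_image_paths = sorted(glob("data/val/images/*.tif"))
val_gt_paths = sorted(glob("data/val/labels/*.tif"))
test_image_paths = sorted(glob("data/test/images/*.tif"))

# Assumed segmenter setup.
predictor = util.get_sam_model(model_type="vit_b")
segmenter = AutomaticMaskGenerator(predictor)

os.makedirs("predictions", exist_ok=True)  # the inference step writes the test segmentations here
run_instance_segmentation_grid_search_and_inference(
    segmenter,
    grid_search_values=default_grid_search_values_amg(),
    val_image_paths=val_image_paths,
    val_gt_paths=val_gt_paths,
    test_image_paths=test_image_paths,
    embedding_dir="embeddings",        # image embeddings cached per image
    prediction_dir="predictions",      # final segmentations for the test images
    result_dir="grid_search_results",  # per-image grid-search csv files
)
```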