micro_sam.evaluation.livecell

Inference and evaluation for the LIVECell dataset and the different cell lines contained in it.

  1"""Inference and evaluation for the [LIVECell dataset](https://www.nature.com/articles/s41592-021-01249-6) and
  2the different cell lines contained in it.
  3"""
  4
  5import os
  6import json
  7import argparse
  8import warnings
  9from glob import glob
 10from typing import List, Optional, Union
 11
 12from segment_anything import SamPredictor
 13
 14from ..instance_segmentation import (
 15    get_predictor_and_decoder,
 16    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
 17)
 18from ..util import get_sam_model
 19from ..evaluation import precompute_all_embeddings
 20from . import instance_segmentation, inference, evaluation
 21
 22
 23CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
 24
 25
 26#
 27# Inference
 28#
 29
 30
 31def _get_livecell_paths(input_folder, split="test", n_val_per_cell_type=None):
 32    assert split in ["val", "test"]
 33    assert os.path.exists(input_folder), f"Data not found at {input_folder}. Please download the LIVECell Dataset"
 34
 35    if split == "test":
 36
 37        img_dir = os.path.join(input_folder, "images", "livecell_test_images")
 38        assert os.path.exists(img_dir), "The LIVECell Dataset is incomplete"
 39        gt_dir = os.path.join(input_folder, "annotations", "livecell_test_images")
 40        assert os.path.exists(gt_dir), "The LIVECell Dataset is incomplete"
 41        image_paths, gt_paths = [], []
 42        for ctype in CELL_TYPES:
 43            counter = 0
 44            for img_path in glob(os.path.join(img_dir, f"{ctype}*")):
 45                if counter == n_val_per_cell_type:
 46                    continue
 47
 48                image_paths.append(img_path)
 49                img_name = os.path.basename(img_path)
 50                gt_path = os.path.join(gt_dir, ctype, img_name)
 51                assert os.path.exists(gt_path), gt_path
 52                gt_paths.append(gt_path)
 53                counter += 1
 54    else:
 55
 56        with open(os.path.join(input_folder, "val.json")) as f:
 57            data = json.load(f)
 58        livecell_val_ids = [i["file_name"] for i in data["images"]]
 59
 60        img_dir = os.path.join(input_folder, "images", "livecell_train_val_images")
 61        assert os.path.exists(img_dir), "The LIVECell Dataset is incomplete"
 62        gt_dir = os.path.join(input_folder, "annotations", "livecell_train_val_images")
 63        assert os.path.exists(gt_dir), "The LIVECell Dataset is incomplete"
 64
 65        image_paths, gt_paths = [], []
 66        count_per_cell_type = {ct: 0 for ct in CELL_TYPES}
 67
 68        for img_name in livecell_val_ids:
 69            cell_type = img_name.split("_")[0]
 70            if n_val_per_cell_type is not None and count_per_cell_type[cell_type] >= n_val_per_cell_type:
 71                continue
 72
 73            image_paths.append(os.path.join(img_dir, img_name))
 74            gt_paths.append(os.path.join(gt_dir, cell_type, img_name))
 75            count_per_cell_type[cell_type] += 1
 76
 77    return sorted(image_paths), sorted(gt_paths)
 78
 79
 80def livecell_inference(
 81    checkpoint: Union[str, os.PathLike],
 82    input_folder: Union[str, os.PathLike],
 83    model_type: str,
 84    experiment_folder: Union[str, os.PathLike],
 85    use_points: bool,
 86    use_boxes: bool,
 87    n_positives: Optional[int] = None,
 88    n_negatives: Optional[int] = None,
 89    prompt_folder: Optional[Union[str, os.PathLike]] = None,
 90    predictor: Optional[SamPredictor] = None,
 91) -> None:
 92    """Run inference for livecell with a fixed prompt setting.
 93
 94    Args:
 95        checkpoint: The segment anything model checkpoint.
 96        input_folder: The folder with the livecell data.
 97        model_type: The type of the segment anything model.
 98        experiment_folder: The folder where to save all data associated with the experiment.
 99        use_points: Whether to use point prompts.
100        use_boxes: Whether to use box prompts.
101        n_positives: The number of positive point prompts.
102        n_negatives: The number of negative point prompts.
103        prompt_folder: The folder where the prompts should be saved.
104        predictor: The segment anything predictor.
105    """
106    image_paths, gt_paths = _get_livecell_paths(input_folder)
107    if predictor is None:
108        predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
109
110    if use_boxes and use_points:
111        assert (n_positives is not None) and (n_negatives is not None)
112        setting_name = f"box/p{n_positives}-n{n_negatives}"
113    elif use_boxes:
114        setting_name = "box/p0-n0"
115    elif use_points:
116        assert (n_positives is not None) and (n_negatives is not None)
117        setting_name = f"points/p{n_positives}-n{n_negatives}"
118    else:
119        raise ValueError("You need to use at least one of point or box prompts.")
120
121    # we organize all folders with data from this experiment beneath 'experiment_folder'
122    prediction_folder = os.path.join(experiment_folder, setting_name)  # where the predicted segmentations are saved
123    os.makedirs(prediction_folder, exist_ok=True)
124    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
125    os.makedirs(embedding_folder, exist_ok=True)
126
127    # NOTE: we can pass an external prompt folder to re-use prompts from another experiment
128    # for reproducibility / fair comparison of results
129    if prompt_folder is None:
130        prompt_folder = os.path.join(experiment_folder, "prompts")
131        os.makedirs(prompt_folder, exist_ok=True)
132
133    inference.run_inference_with_prompts(
134        predictor,
135        image_paths,
136        gt_paths,
137        embedding_dir=embedding_folder,
138        prediction_dir=prediction_folder,
139        prompt_save_dir=prompt_folder,
140        use_points=use_points,
141        use_boxes=use_boxes,
142        n_positives=n_positives,
143        n_negatives=n_negatives,
144    )
145
146
147def run_livecell_precompute_embeddings(
148    checkpoint: Union[str, os.PathLike],
149    input_folder: Union[str, os.PathLike],
150    model_type: str,
151    experiment_folder: Union[str, os.PathLike],
152    n_val_per_cell_type: int = 25,
153) -> None:
154    """Run precomputation of val and test image embeddings for livecell.
155
156    Args:
157        checkpoint: The segment anything model checkpoint.
158        input_folder: The folder with the livecell data.
159        model_type: The type of the segment anything model.
160        experiment_folder: The folder where to save all data associated with the experiment.
161        n_val_per_cell_type: The number of validation images per cell type.
162    """
163    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the embeddings will be saved
164    os.makedirs(embedding_folder, exist_ok=True)
165
166    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
167
168    val_image_paths, _ = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type)
169    test_image_paths, _ = _get_livecell_paths(input_folder, "test")
170
171    precompute_all_embeddings(predictor, val_image_paths, embedding_folder)
172    precompute_all_embeddings(predictor, test_image_paths, embedding_folder)
173
174
175def run_livecell_iterative_prompting(
176    checkpoint: Union[str, os.PathLike],
177    input_folder: Union[str, os.PathLike],
178    model_type: str,
179    experiment_folder: Union[str, os.PathLike],
180    start_with_box: bool = False,
181    use_masks: bool = False,
182) -> str:
183    """Run inference on livecell with iterative prompting setting.
184
185    Args:
186        checkpoint: The segment anything model checkpoint.
187        input_folder: The folder with the livecell data.
188        model_type: The type of the segment anything model.
189        experiment_folder: The folder where to save all data associated with the experiment.
190        start_with_box: Whether to use a bounding box instead of a single point as the first prompt.
191        use_masks: Whether to make use of logits from previous prompt-based segmentation.
192
193    """
194    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the embeddings will be saved
195    os.makedirs(embedding_folder, exist_ok=True)
196
197    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
198
199    # where the predictions are saved
200    prediction_folder = os.path.join(
201        experiment_folder, "start_with_box" if start_with_box else "start_with_point"
202    )
203
204    image_paths, gt_paths = _get_livecell_paths(input_folder, "test")
205
206    inference.run_inference_with_iterative_prompting(
207        predictor=predictor,
208        image_paths=image_paths,
209        gt_paths=gt_paths,
210        embedding_dir=embedding_folder,
211        prediction_dir=prediction_folder,
212        start_with_box_prompt=start_with_box,
213        use_masks=use_masks,
214    )
215    return prediction_folder
216
217
218def run_livecell_amg(
219    checkpoint: Union[str, os.PathLike],
220    input_folder: Union[str, os.PathLike],
221    model_type: str,
222    experiment_folder: Union[str, os.PathLike],
223    iou_thresh_values: Optional[List[float]] = None,
224    stability_score_values: Optional[List[float]] = None,
225    verbose_gs: bool = False,
226    n_val_per_cell_type: int = 25,
227) -> str:
228    """Run automatic mask generation grid-search and inference for livecell.
229
230    Args:
231        checkpoint: The segment anything model checkpoint.
232        input_folder: The folder with the livecell data.
233        model_type: The type of the segment anything model.
234        experiment_folder: The folder where to save all data associated with the experiment.
235        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
236            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
237        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
238            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
239        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
240        n_val_per_cell_type: The number of validation images per cell type.
241
242    Returns:
243        The path where the predicted images are stored.
244    """
245    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
246    os.makedirs(embedding_folder, exist_ok=True)
247
248    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
249    amg = AutomaticMaskGenerator(predictor)
250    amg_prefix = "amg"
251
252    # where the predictions are saved
253    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
254    os.makedirs(prediction_folder, exist_ok=True)
255
256    # where the grid-search results are saved
257    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
258    os.makedirs(gs_result_folder, exist_ok=True)
259
260    val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type)
261    test_image_paths, _ = _get_livecell_paths(input_folder, "test")
262
263    grid_search_values = instance_segmentation.default_grid_search_values_amg(
264        iou_thresh_values=iou_thresh_values,
265        stability_score_values=stability_score_values,
266    )
267
268    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
269        amg, grid_search_values, val_image_paths, val_gt_paths, test_image_paths,
270        embedding_folder, prediction_folder, gs_result_folder, verbose_gs=verbose_gs
271    )
272    return prediction_folder
273
274
275def run_livecell_instance_segmentation_with_decoder(
276    checkpoint: Union[str, os.PathLike],
277    input_folder: Union[str, os.PathLike],
278    model_type: str,
279    experiment_folder: Union[str, os.PathLike],
280    center_distance_threshold_values: Optional[List[float]] = None,
281    boundary_distance_threshold_values: Optional[List[float]] = None,
282    distance_smoothing_values: Optional[List[float]] = None,
283    min_size_values: Optional[List[float]] = None,
284    verbose_gs: bool = False,
285    n_val_per_cell_type: int = 25,
286) -> str:
287    """Run automatic mask generation grid-search and inference for livecell.
288
289    Args:
290        checkpoint: The segment anything model checkpoint.
291        input_folder: The folder with the livecell data.
292        model_type: The type of the segment anything model.
293        experiment_folder: The folder where to save all data associated with the experiment.
294        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
295            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
296        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
297            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
298        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
299            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
300        min_size_values: The values for `min_size` used in the gridsearch.
301            By default the values 50, 100 and 200 are used.
302        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
303        n_val_per_cell_type: The number of validation images per cell type.
304
305    Returns:
306        The path where the predicted images are stored.
307    """
308    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
309    os.makedirs(embedding_folder, exist_ok=True)
310
311    predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint)
312    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
313    seg_prefix = "instance_segmentation_with_decoder"
314
315    # where the predictions are saved
316    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
317    os.makedirs(prediction_folder, exist_ok=True)
318
319    # where the grid-search results are saved
320    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
321    os.makedirs(gs_result_folder, exist_ok=True)
322
323    val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type)
324    test_image_paths, _ = _get_livecell_paths(input_folder, "test")
325
326    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder(
327        center_distance_threshold_values=center_distance_threshold_values,
328        boundary_distance_threshold_values=boundary_distance_threshold_values,
329        distance_smoothing_values=distance_smoothing_values,
330        min_size_values=min_size_values
331    )
332
333    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
334        segmenter, grid_search_values, val_image_paths, val_gt_paths, test_image_paths, embedding_dir=embedding_folder,
335        prediction_dir=prediction_folder, result_dir=gs_result_folder, verbose_gs=verbose_gs
336    )
337    return prediction_folder
338
339
340def run_livecell_inference() -> None:
341    """Run LIVECell inference with command line tool."""
342    parser = argparse.ArgumentParser()
343
344    # the checkpoint, input and experiment folder
345    parser.add_argument(
346        "-c", "--ckpt", type=str, required=True,
347        help="Provide model checkpoints (vanilla / finetuned)."
348    )
349    parser.add_argument(
350        "-i", "--input", type=str, required=True,
351        help="Provide the data directory for LIVECell Dataset."
352    )
353    parser.add_argument(
354        "-e", "--experiment_folder", type=str, required=True,
355        help="Provide the path where all data for the inference run will be stored."
356    )
357    parser.add_argument(
358        "-m", "--model", type=str, required=True,
359        help="Pass the checkpoint-specific model name being used for inference."
360    )
361
362    # the experiment type:
363    # 1. precompute image embeddings
364    # 2. iterative prompting-based interactive instance segmentation (iterative prompting)
365    #     - iterative prompting
366    #         - starting with point
367    #         - starting with box
368    # 3. automatic segmentation (auto)
369    #     - automatic mask generation (amg)
370    #     - automatic instance segmentation (ais)
371
372    parser.add_argument("-p", "--precompute_embeddings", action="store_true")
373    parser.add_argument("-ip", "--iterative_prompting", action="store_true")
374    parser.add_argument("-amg", "--auto_mask_generation", action="store_true")
375    parser.add_argument("-ais", "--auto_instance_segmentation", action="store_true")
376
377    # the prompt settings for starting iterative prompting for interactive instance segmentation
378    #     - (default: start with points)
379    parser.add_argument(
380        "-b", "--start_with_box", action="store_true", help="Start with box for iterative prompt-based segmentation."
381    )
382    parser.add_argument(
383        "--use_masks", action="store_true",
384        help="Whether to use logits from previous interactive segmentation as inputs for iterative prompting."
385    )
386    parser.add_argument(
387        "--n_val_per_cell_type", default=25, type=int,
388        help="How many validation samples per cell type to be used for grid search."
389    )
390
391    args = parser.parse_args()
392    if sum([args.iterative_prompting, args.auto_mask_generation, args.auto_instance_segmentation]) > 1:
393        warnings.warn(
394            "It's recommended to choose either from 'iterative_prompting', 'auto_mask_generation' or "
395            "'auto_instance_segmentation' at once, else it might take a while."
396        )
397
398    if args.precompute_embeddings:
399        run_livecell_precompute_embeddings(
400            args.ckpt, args.input, args.model, args.experiment_folder, args.n_val_per_cell_type
401        )
402
403    if args.iterative_prompting:
404        run_livecell_iterative_prompting(
405            args.ckpt, args.input, args.model, args.experiment_folder,
406            start_with_box=args.start_with_box, use_masks=args.use_masks
407        )
408
409    if args.auto_instance_segmentation:
410        run_livecell_instance_segmentation_with_decoder(
411            args.ckpt, args.input, args.model, args.experiment_folder, n_val_per_cell_type=args.n_val_per_cell_type
412        )
413
414    if args.auto_mask_generation:
415        run_livecell_amg(
416            args.ckpt, args.input, args.model, args.experiment_folder, n_val_per_cell_type=args.n_val_per_cell_type
417        )
418
419
420#
421# Evaluation
422#
423
424
425def run_livecell_evaluation() -> None:
426    """Run LIVECell evaluation with command line tool."""
427    parser = argparse.ArgumentParser()
428    parser.add_argument(
429        "-i", "--input", required=True, help="Provide the data directory for LIVECell Dataset"
430    )
431    parser.add_argument(
432        "-e", "--experiment_folder", required=True,
433        help="Provide the path where the inference data is stored."
434    )
435    parser.add_argument(
436        "-f", "--force", action="store_true",
437        help="Force recomputation of already cached eval results."
438    )
439    args = parser.parse_args()
440
441    _, gt_paths = _get_livecell_paths(args.input, "test")
442
443    experiment_folder = args.experiment_folder
444    save_root = os.path.join(experiment_folder, "results")
445
446    inference_root_names = [
447        "amg/inference", "instance_segmentation_with_decoder/inference", "start_with_box", "start_with_point"
448    ]
449    for inf_root in inference_root_names:
450        pred_root = os.path.join(experiment_folder, inf_root)
451        if not os.path.exists(pred_root):
452            print(
453                f"The inference for '{inf_root}' were not generated.",
454                "Please run the inference first to evaluate on the predictions."
455            )
456            continue
457
458        if inf_root.startswith("start_with"):
459            evaluation.run_evaluation_for_iterative_prompting(
460                gt_paths=gt_paths,
461                prediction_root=pred_root,
462                experiment_folder=experiment_folder,
463                start_with_box_prompt=(inf_root == "start_with_box"),
464                overwrite_results=args.force
465            )
466        else:
467            pred_paths = sorted(glob(os.path.join(pred_root, "*")))
468            save_name = inf_root.split("/")[0]
469            save_path = os.path.join(save_root, f"{save_name}.csv")
470
471            if args.force and os.path.exists(save_path):
472                os.remove(save_path)
473
474            results = evaluation.run_evaluation(
475                gt_paths=gt_paths,
476                prediction_paths=pred_paths,
477                save_path=save_path,
478            )
479            print(results)
CELL_TYPES = ['A172', 'BT474', 'BV2', 'Huh7', 'MCF7', 'SHSY5Y', 'SkBr3', 'SKOV3']
def livecell_inference( checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: Optional[int] = None, n_negatives: Optional[int] = None, prompt_folder: Union[os.PathLike, str, NoneType] = None, predictor: Optional[segment_anything.predictor.SamPredictor] = None) -> None:

Run inference for livecell with a fixed prompt setting.

Arguments:
  • checkpoint: The segment anything model checkpoint.
  • input_folder: The folder with the livecell data.
  • model_type: The type of the segment anything model.
  • experiment_folder: The folder where to save all data associated with the experiment.
  • use_points: Whether to use point prompts.
  • use_boxes: Whether to use box prompts.
  • n_positives: The number of positive point prompts.
  • n_negatives: The number of negative point prompts.
  • prompt_folder: The folder where the prompts should be saved.
  • predictor: The segment anything predictor.
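
A minimal usage sketch for this function, assuming a local LIVECell download and a SAM vit_b checkpoint; all paths below are placeholders:

    from micro_sam.evaluation.livecell import livecell_inference

    # Placeholder paths; replace with your own checkpoint, data and output locations.
    livecell_inference(
        checkpoint="/path/to/checkpoint.pt",
        input_folder="/path/to/livecell",
        model_type="vit_b",
        experiment_folder="/path/to/experiment",
        use_points=True,   # prompt with a single positive point per object ...
        use_boxes=False,   # ... and no box prompts
        n_positives=1,
        n_negatives=0,
    )
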
def run_livecell_precompute_embeddings( checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], n_val_per_cell_type: int = 25) -> None:

Run precomputation of val and test image embeddings for livecell.

Arguments:
  • checkpoint: The segment anything model checkpoint.
  • input_folder: The folder with the livecell data.
  • model_type: The type of the segment anything model.
  • experiment_folder: The folder where to save all data associated with the experiment.
  • n_val_per_cell_type: The number of validation images per cell type.
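
A short sketch of how this precomputation step could be called before any of the inference runs below (paths are placeholders):

    from micro_sam.evaluation.livecell import run_livecell_precompute_embeddings

    # Precomputes embeddings for the validation subset (25 images per cell type)
    # and for the full test split, storing them under <experiment_folder>/embeddings.
    run_livecell_precompute_embeddings(
        checkpoint="/path/to/checkpoint.pt",
        input_folder="/path/to/livecell",
        model_type="vit_b",
        experiment_folder="/path/to/experiment",
        n_val_per_cell_type=25,
    )
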
def run_livecell_iterative_prompting( checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], start_with_box: bool = False, use_masks: bool = False) -> str:

Run inference on livecell with the iterative prompting setting.

Arguments:
  • checkpoint: The segment anything model checkpoint.
  • input_folder: The folder with the livecell data.
  • model_type: The type of the segment anything model.
  • experiment_folder: The folder where to save all data associated with the experiment.
  • start_with_box: Whether to use a bounding box instead of a single point as the first prompt.
  • use_masks: Whether to make use of logits from previous prompt-based segmentation.
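
A usage sketch, assuming the image embeddings have already been precomputed into the same experiment folder (paths are placeholders):

    from micro_sam.evaluation.livecell import run_livecell_iterative_prompting

    prediction_folder = run_livecell_iterative_prompting(
        checkpoint="/path/to/checkpoint.pt",
        input_folder="/path/to/livecell",
        model_type="vit_b",
        experiment_folder="/path/to/experiment",
        start_with_box=False,  # start from a single point prompt
        use_masks=True,        # feed the previous mask logits back into the model
    )
    print("Predictions written to:", prediction_folder)
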
def run_livecell_amg( checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, verbose_gs: bool = False, n_val_per_cell_type: int = 25) -> str:

Run automatic mask generation grid-search and inference for livecell.

Arguments:
  • checkpoint: The segment anything model checkpoint.
  • input_folder: The folder with the livecell data.
  • model_type: The type of the segment anything model.
  • experiment_folder: The folder where to save all data associated with the experiment.
  • iou_thresh_values: The values for pred_iou_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
  • stability_score_values: The values for stability_score_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
  • verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
  • n_val_per_cell_type: The number of validation images per cell type.
Returns:
  The path where the predicted images are stored.
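
A sketch of running the AMG grid search with a custom, coarser grid; the threshold values below are illustrative, not recommendations, and all paths are placeholders:

    from micro_sam.evaluation.livecell import run_livecell_amg

    prediction_folder = run_livecell_amg(
        checkpoint="/path/to/checkpoint.pt",
        input_folder="/path/to/livecell",
        model_type="vit_b",
        experiment_folder="/path/to/experiment",
        iou_thresh_values=[0.7, 0.8, 0.88],      # coarser than the default grid
        stability_score_values=[0.7, 0.8, 0.9],  # coarser than the default grid
        n_val_per_cell_type=10,                  # fewer validation images to speed up the search
    )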

def run_livecell_instance_segmentation_with_decoder( checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], center_distance_threshold_values: Optional[List[float]] = None, boundary_distance_threshold_values: Optional[List[float]] = None, distance_smoothing_values: Optional[List[float]] = None, min_size_values: Optional[List[float]] = None, verbose_gs: bool = False, n_val_per_cell_type: int = 25) -> str:

Run automatic instance segmentation with decoder (AIS) grid-search and inference for livecell.

Arguments:
  • checkpoint: The segment anything model checkpoint.
  • input_folder: The folder with the livecell data.
  • model_type: The type of the segment anything model.
  • experiment_folder: The folder where to save all data associated with the experiment.
  • center_distance_threshold_values: The values for center_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • boundary_distance_threshold_values: The values for boundary_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
  • distance_smoothing_values: The values for distance_smoothing used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
  • min_size_values: The values for min_size used in the gridsearch. By default the values 50, 100 and 200 are used.
  • verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
  • n_val_per_cell_type: The number of validation images per cell type.
Returns:
  The path where the predicted images are stored.
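
Analogously, a sketch for the decoder-based automatic instance segmentation using the default grid-search ranges (paths are placeholders; the checkpoint is assumed to contain the additional decoder weights):

    from micro_sam.evaluation.livecell import run_livecell_instance_segmentation_with_decoder

    prediction_folder = run_livecell_instance_segmentation_with_decoder(
        checkpoint="/path/to/checkpoint_with_decoder.pt",
        input_folder="/path/to/livecell",
        model_type="vit_b",
        experiment_folder="/path/to/experiment",
        # Leaving the *_values arguments as None keeps the default grid-search ranges.
    )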

def run_livecell_inference() -> None:

Run LIVECell inference with the command line tool.
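
Since this entry point reads its options from the command line via argparse, one way to drive it from Python is to set sys.argv before calling it; a sketch using the flags defined above (paths are placeholders):

    import sys

    from micro_sam.evaluation.livecell import run_livecell_inference

    sys.argv = [
        "livecell_inference",
        "-c", "/path/to/checkpoint.pt",  # model checkpoint
        "-i", "/path/to/livecell",       # LIVECell data directory
        "-e", "/path/to/experiment",     # experiment folder for all outputs
        "-m", "vit_b",                   # model type matching the checkpoint
        "-p",                            # precompute the embeddings
        "-ip",                           # then run iterative prompting
        "--use_masks",
    ]
    run_livecell_inference()
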

def run_livecell_evaluation() -> None:

Run LIVECell evaluation with the command line tool.
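
The evaluation entry point can be driven the same way; a sketch that evaluates whatever predictions already exist under the experiment folder (paths are placeholders):

    import sys

    from micro_sam.evaluation.livecell import run_livecell_evaluation

    sys.argv = [
        "livecell_evaluation",
        "-i", "/path/to/livecell",    # LIVECell data directory, used to load the ground-truth labels
        "-e", "/path/to/experiment",  # experiment folder that holds the predictions
        # add "-f" to force recomputation of cached evaluation results
    ]
    run_livecell_evaluation()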