micro_sam.evaluation.livecell
Inference and evaluation for the LIVECell dataset and the different cell lines contained in it.
1"""Inference and evaluation for the [LIVECell dataset](https://www.nature.com/articles/s41592-021-01249-6) and 2the different cell lines contained in it. 3""" 4 5import os 6import json 7import argparse 8import warnings 9from glob import glob 10from typing import List, Optional, Union 11 12from segment_anything import SamPredictor 13 14from ..instance_segmentation import ( 15 get_predictor_and_decoder, 16 AutomaticMaskGenerator, InstanceSegmentationWithDecoder, 17) 18from ..util import get_sam_model 19from ..evaluation import precompute_all_embeddings 20from . import instance_segmentation, inference, evaluation 21 22 23CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] 24 25 26# 27# Inference 28# 29 30 31def _get_livecell_paths(input_folder, split="test", n_val_per_cell_type=None): 32 assert split in ["val", "test"] 33 assert os.path.exists(input_folder), f"Data not found at {input_folder}. Please download the LIVECell Dataset" 34 35 if split == "test": 36 37 img_dir = os.path.join(input_folder, "images", "livecell_test_images") 38 assert os.path.exists(img_dir), "The LIVECell Dataset is incomplete" 39 gt_dir = os.path.join(input_folder, "annotations", "livecell_test_images") 40 assert os.path.exists(gt_dir), "The LIVECell Dataset is incomplete" 41 image_paths, gt_paths = [], [] 42 for ctype in CELL_TYPES: 43 counter = 0 44 for img_path in glob(os.path.join(img_dir, f"{ctype}*")): 45 if counter == n_val_per_cell_type: 46 continue 47 48 image_paths.append(img_path) 49 img_name = os.path.basename(img_path) 50 gt_path = os.path.join(gt_dir, ctype, img_name) 51 assert os.path.exists(gt_path), gt_path 52 gt_paths.append(gt_path) 53 counter += 1 54 else: 55 56 with open(os.path.join(input_folder, "val.json")) as f: 57 data = json.load(f) 58 livecell_val_ids = [i["file_name"] for i in data["images"]] 59 60 img_dir = os.path.join(input_folder, "images", "livecell_train_val_images") 61 assert os.path.exists(img_dir), "The LIVECell Dataset is incomplete" 62 gt_dir = os.path.join(input_folder, "annotations", "livecell_train_val_images") 63 assert os.path.exists(gt_dir), "The LIVECell Dataset is incomplete" 64 65 image_paths, gt_paths = [], [] 66 count_per_cell_type = {ct: 0 for ct in CELL_TYPES} 67 68 for img_name in livecell_val_ids: 69 cell_type = img_name.split("_")[0] 70 if n_val_per_cell_type is not None and count_per_cell_type[cell_type] >= n_val_per_cell_type: 71 continue 72 73 image_paths.append(os.path.join(img_dir, img_name)) 74 gt_paths.append(os.path.join(gt_dir, cell_type, img_name)) 75 count_per_cell_type[cell_type] += 1 76 77 return sorted(image_paths), sorted(gt_paths) 78 79 80def livecell_inference( 81 checkpoint: Union[str, os.PathLike], 82 input_folder: Union[str, os.PathLike], 83 model_type: str, 84 experiment_folder: Union[str, os.PathLike], 85 use_points: bool, 86 use_boxes: bool, 87 n_positives: Optional[int] = None, 88 n_negatives: Optional[int] = None, 89 prompt_folder: Optional[Union[str, os.PathLike]] = None, 90 predictor: Optional[SamPredictor] = None, 91) -> None: 92 """Run inference for livecell with a fixed prompt setting. 93 94 Args: 95 checkpoint: The segment anything model checkpoint. 96 input_folder: The folder with the livecell data. 97 model_type: The type of the segment anything model. 98 experiment_folder: The folder where to save all data associated with the experiment. 99 use_points: Whether to use point prompts. 100 use_boxes: Whether to use box prompts. 101 n_positives: The number of positive point prompts. 
102 n_negatives: The number of negative point prompts. 103 prompt_folder: The folder where the prompts should be saved. 104 predictor: The segment anything predictor. 105 """ 106 image_paths, gt_paths = _get_livecell_paths(input_folder) 107 if predictor is None: 108 predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint) 109 110 if use_boxes and use_points: 111 assert (n_positives is not None) and (n_negatives is not None) 112 setting_name = f"box/p{n_positives}-n{n_negatives}" 113 elif use_boxes: 114 setting_name = "box/p0-n0" 115 elif use_points: 116 assert (n_positives is not None) and (n_negatives is not None) 117 setting_name = f"points/p{n_positives}-n{n_negatives}" 118 else: 119 raise ValueError("You need to use at least one of point or box prompts.") 120 121 # we organize all folders with data from this experiment beneath 'experiment_folder' 122 prediction_folder = os.path.join(experiment_folder, setting_name) # where the predicted segmentations are saved 123 os.makedirs(prediction_folder, exist_ok=True) 124 embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved 125 os.makedirs(embedding_folder, exist_ok=True) 126 127 # NOTE: we can pass an external prompt folder, to make re-use prompts from another experiment 128 # for reproducibility / fair comparison of results 129 if prompt_folder is None: 130 prompt_folder = os.path.join(experiment_folder, "prompts") 131 os.makedirs(prompt_folder, exist_ok=True) 132 133 inference.run_inference_with_prompts( 134 predictor, 135 image_paths, 136 gt_paths, 137 embedding_dir=embedding_folder, 138 prediction_dir=prediction_folder, 139 prompt_save_dir=prompt_folder, 140 use_points=use_points, 141 use_boxes=use_boxes, 142 n_positives=n_positives, 143 n_negatives=n_negatives, 144 ) 145 146 147def run_livecell_precompute_embeddings( 148 checkpoint: Union[str, os.PathLike], 149 input_folder: Union[str, os.PathLike], 150 model_type: str, 151 experiment_folder: Union[str, os.PathLike], 152 n_val_per_cell_type: int = 25, 153) -> None: 154 """Run precomputation of val and test image embeddings for livecell. 155 156 Args: 157 checkpoint: The segment anything model checkpoint. 158 input_folder: The folder with the livecell data. 159 model_type: The type of the segmenta anything model. 160 experiment_folder: The folder where to save all data associated with the experiment. 161 n_val_per_cell_type: The number of validation images per cell type. 162 """ 163 embedding_folder = os.path.join(experiment_folder, "embeddings") # where the embeddings will be saved 164 os.makedirs(embedding_folder, exist_ok=True) 165 166 predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint) 167 168 val_image_paths, _ = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type) 169 test_image_paths, _ = _get_livecell_paths(input_folder, "test") 170 171 precompute_all_embeddings(predictor, val_image_paths, embedding_folder) 172 precompute_all_embeddings(predictor, test_image_paths, embedding_folder) 173 174 175def run_livecell_iterative_prompting( 176 checkpoint: Union[str, os.PathLike], 177 input_folder: Union[str, os.PathLike], 178 model_type: str, 179 experiment_folder: Union[str, os.PathLike], 180 start_with_box: bool = False, 181 use_masks: bool = False, 182) -> str: 183 """Run inference on livecell with iterative prompting setting. 184 185 Args: 186 checkpoint: The segment anything model checkpoint. 187 input_folder: The folder with the livecell data. 
188 model_type: The type of the segment anything model. 189 experiment_folder: The folder where to save all data associated with the experiment. 190 start_with_box_prompt: Whether to use the first prompt as bounding box or a single point. 191 use_masks: Whether to make use of logits from previous prompt-based segmentation. 192 193 """ 194 embedding_folder = os.path.join(experiment_folder, "embeddings") # where the embeddings will be saved 195 os.makedirs(embedding_folder, exist_ok=True) 196 197 predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint) 198 199 # where the predictions are saved 200 prediction_folder = os.path.join( 201 experiment_folder, "start_with_box" if start_with_box else "start_with_point" 202 ) 203 204 image_paths, gt_paths = _get_livecell_paths(input_folder, "test") 205 206 inference.run_inference_with_iterative_prompting( 207 predictor=predictor, 208 image_paths=image_paths, 209 gt_paths=gt_paths, 210 embedding_dir=embedding_folder, 211 prediction_dir=prediction_folder, 212 start_with_box_prompt=start_with_box, 213 use_masks=use_masks, 214 ) 215 return prediction_folder 216 217 218def run_livecell_amg( 219 checkpoint: Union[str, os.PathLike], 220 input_folder: Union[str, os.PathLike], 221 model_type: str, 222 experiment_folder: Union[str, os.PathLike], 223 iou_thresh_values: Optional[List[float]] = None, 224 stability_score_values: Optional[List[float]] = None, 225 verbose_gs: bool = False, 226 n_val_per_cell_type: int = 25, 227) -> str: 228 """Run automatic mask generation grid-search and inference for livecell. 229 230 Args: 231 checkpoint: The segment anything model checkpoint. 232 input_folder: The folder with the livecell data. 233 model_type: The type of the segmenta anything model. 234 experiment_folder: The folder where to save all data associated with the experiment. 235 iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. 236 By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. 237 stability_score_values: The values for `stability_score_thresh` used in the gridsearch. 238 By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. 239 verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. 240 n_val_per_cell_type: The number of validation images per cell type. 241 242 Returns: 243 The path where the predicted images are stored. 
244 """ 245 embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved 246 os.makedirs(embedding_folder, exist_ok=True) 247 248 predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint) 249 amg = AutomaticMaskGenerator(predictor) 250 amg_prefix = "amg" 251 252 # where the predictions are saved 253 prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference") 254 os.makedirs(prediction_folder, exist_ok=True) 255 256 # where the grid-search results are saved 257 gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search") 258 os.makedirs(gs_result_folder, exist_ok=True) 259 260 val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type) 261 test_image_paths, _ = _get_livecell_paths(input_folder, "test") 262 263 grid_search_values = instance_segmentation.default_grid_search_values_amg( 264 iou_thresh_values=iou_thresh_values, 265 stability_score_values=stability_score_values, 266 ) 267 268 instance_segmentation.run_instance_segmentation_grid_search_and_inference( 269 amg, grid_search_values, val_image_paths, val_gt_paths, test_image_paths, 270 embedding_folder, prediction_folder, gs_result_folder, verbose_gs=verbose_gs 271 ) 272 return prediction_folder 273 274 275def run_livecell_instance_segmentation_with_decoder( 276 checkpoint: Union[str, os.PathLike], 277 input_folder: Union[str, os.PathLike], 278 model_type: str, 279 experiment_folder: Union[str, os.PathLike], 280 center_distance_threshold_values: Optional[List[float]] = None, 281 boundary_distance_threshold_values: Optional[List[float]] = None, 282 distance_smoothing_values: Optional[List[float]] = None, 283 min_size_values: Optional[List[float]] = None, 284 verbose_gs: bool = False, 285 n_val_per_cell_type: int = 25, 286) -> str: 287 """Run automatic mask generation grid-search and inference for livecell. 288 289 Args: 290 checkpoint: The segment anything model checkpoint. 291 input_folder: The folder with the livecell data. 292 model_type: The type of the segmenta anything model. 293 experiment_folder: The folder where to save all data associated with the experiment. 294 center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch. 295 By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used. 296 boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch. 297 By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used. 298 distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch. 299 By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used. 300 min_size_values: The values for `min_size` used in the gridsearch. 301 By default the values 50, 100 and 200 are used. 302 verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. 303 n_val_per_cell_type: The number of validation images per cell type. 304 305 Returns: 306 The path where the predicted images are stored. 
307 """ 308 embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved 309 os.makedirs(embedding_folder, exist_ok=True) 310 311 predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint) 312 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 313 seg_prefix = "instance_segmentation_with_decoder" 314 315 # where the predictions are saved 316 prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference") 317 os.makedirs(prediction_folder, exist_ok=True) 318 319 # where the grid-search results are saved 320 gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search") 321 os.makedirs(gs_result_folder, exist_ok=True) 322 323 val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type) 324 test_image_paths, _ = _get_livecell_paths(input_folder, "test") 325 326 grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder( 327 center_distance_threshold_values=center_distance_threshold_values, 328 boundary_distance_threshold_values=boundary_distance_threshold_values, 329 distance_smoothing_values=distance_smoothing_values, 330 min_size_values=min_size_values 331 ) 332 333 instance_segmentation.run_instance_segmentation_grid_search_and_inference( 334 segmenter, grid_search_values, val_image_paths, val_gt_paths, test_image_paths, embedding_dir=embedding_folder, 335 prediction_dir=prediction_folder, result_dir=gs_result_folder, verbose_gs=verbose_gs 336 ) 337 return prediction_folder 338 339 340def run_livecell_inference() -> None: 341 """Run LIVECell inference with command line tool.""" 342 parser = argparse.ArgumentParser() 343 344 # the checkpoint, input and experiment folder 345 parser.add_argument( 346 "-c", "--ckpt", type=str, required=True, 347 help="Provide model checkpoints (vanilla / finetuned)." 348 ) 349 parser.add_argument( 350 "-i", "--input", type=str, required=True, 351 help="Provide the data directory for LIVECell Dataset." 352 ) 353 parser.add_argument( 354 "-e", "--experiment_folder", type=str, required=True, 355 help="Provide the path where all data for the inference run will be stored." 356 ) 357 parser.add_argument( 358 "-m", "--model", type=str, required=True, 359 help="Pass the checkpoint-specific model name being used for inference." 360 ) 361 362 # the experiment type: 363 # 1. precompute image embeddings 364 # 2. iterative prompting-based interactive instance segmentation (iterative prompting) 365 # - iterative prompting 366 # - starting with point 367 # - starting with box 368 # 3. automatic segmentation (auto) 369 # - automatic mask generation (amg) 370 # - automatic instance segmentation (ais) 371 372 parser.add_argument("-p", "--precompute_embeddings", action="store_true") 373 parser.add_argument("-ip", "--iterative_prompting", action="store_true") 374 parser.add_argument("-amg", "--auto_mask_generation", action="store_true") 375 parser.add_argument("-ais", "--auto_instance_segmentation", action="store_true") 376 377 # the prompt settings for starting iterative prompting for interactive instance segmentation 378 # - (default: start with points) 379 parser.add_argument( 380 "-b", "--start_with_box", action="store_true", help="Start with box for iterative prompt-based segmentation." 
381 ) 382 parser.add_argument( 383 "--use_masks", action="store_true", 384 help="Whether to use logits from previous interactive segmentation as inputs for iterative prompting." 385 ) 386 parser.add_argument( 387 "--n_val_per_cell_type", default=25, type=int, 388 help="How many validation samples per cell type to be used for grid search." 389 ) 390 391 args = parser.parse_args() 392 if sum([args.iterative_prompting, args.auto_mask_generation, args.auto_instance_segmentation]) > 1: 393 warnings.warn( 394 "It's recommended to choose either from 'iterative_prompting', 'auto_mask_generation' or " 395 "'auto_instance_segmentation' at once, else it might take a while." 396 ) 397 398 if args.precompute_embeddings: 399 run_livecell_precompute_embeddings( 400 args.ckpt, args.input, args.model, args.experiment_folder, args.n_val_per_cell_type 401 ) 402 403 if args.iterative_prompting: 404 run_livecell_iterative_prompting( 405 args.ckpt, args.input, args.model, args.experiment_folder, 406 start_with_box=args.start_with_box, use_masks=args.use_masks 407 ) 408 409 if args.auto_instance_segmentation: 410 run_livecell_instance_segmentation_with_decoder( 411 args.ckpt, args.input, args.model, args.experiment_folder, n_val_per_cell_type=args.n_val_per_cell_type 412 ) 413 414 if args.auto_mask_generation: 415 run_livecell_amg( 416 args.ckpt, args.input, args.model, args.experiment_folder, n_val_per_cell_type=args.n_val_per_cell_type 417 ) 418 419 420# 421# Evaluation 422# 423 424 425def run_livecell_evaluation() -> None: 426 """Run LIVECell evaluation with command line tool.""" 427 parser = argparse.ArgumentParser() 428 parser.add_argument( 429 "-i", "--input", required=True, help="Provide the data directory for LIVECell Dataset" 430 ) 431 parser.add_argument( 432 "-e", "--experiment_folder", required=True, 433 help="Provide the path where the inference data is stored." 434 ) 435 parser.add_argument( 436 "-f", "--force", action="store_true", 437 help="Force recomputation of already cached eval results." 438 ) 439 args = parser.parse_args() 440 441 _, gt_paths = _get_livecell_paths(args.input, "test") 442 443 experiment_folder = args.experiment_folder 444 save_root = os.path.join(experiment_folder, "results") 445 446 inference_root_names = [ 447 "amg/inference", "instance_segmentation_with_decoder/inference", "start_with_box", "start_with_point" 448 ] 449 for inf_root in inference_root_names: 450 pred_root = os.path.join(experiment_folder, inf_root) 451 if not os.path.exists(pred_root): 452 print( 453 f"The inference for '{inf_root}' were not generated.", 454 "Please run the inference first to evaluate on the predictions." 455 ) 456 continue 457 458 if inf_root.startswith("start_with"): 459 evaluation.run_evaluation_for_iterative_prompting( 460 gt_paths=gt_paths, 461 prediction_root=pred_root, 462 experiment_folder=experiment_folder, 463 start_with_box_prompt=(inf_root == "start_with_box"), 464 overwrite_results=args.force 465 ) 466 else: 467 pred_paths = sorted(glob(os.path.join(pred_root, "*"))) 468 save_name = inf_root.split("/")[0] 469 save_path = os.path.join(save_root, f"{save_name}.csv") 470 471 if args.force and os.path.exists(save_path): 472 os.remove(save_path) 473 474 results = evaluation.run_evaluation( 475 gt_paths=gt_paths, 476 prediction_paths=pred_paths, 477 save_path=save_path, 478 ) 479 print(results)
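For orientation, this is the directory layout that _get_livecell_paths expects under input_folder, read off the code above; the individual file names inside the image folders follow the LIVECell naming scheme and are not listed here:

# <input_folder>/
#     val.json                                   # json file whose "images" entries name the validation images
#     images/
#         livecell_train_val_images/             # training + validation images
#         livecell_test_images/                  # test images
#     annotations/
#         livecell_train_val_images/<CELL_TYPE>/  # ground-truth files, grouped per cell type
#         livecell_test_images/<CELL_TYPE>/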
CELL_TYPES = ['A172', 'BT474', 'BV2', 'Huh7', 'MCF7', 'SHSY5Y', 'SkBr3', 'SKOV3']
def livecell_inference(checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: Optional[int] = None, n_negatives: Optional[int] = None, prompt_folder: Union[os.PathLike, str, NoneType] = None, predictor: Optional[segment_anything.predictor.SamPredictor] = None) -> None:
def livecell_inference(
    checkpoint: Union[str, os.PathLike],
    input_folder: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: Optional[int] = None,
    n_negatives: Optional[int] = None,
    prompt_folder: Optional[Union[str, os.PathLike]] = None,
    predictor: Optional[SamPredictor] = None,
) -> None:
    """Run inference for livecell with a fixed prompt setting.

    Args:
        checkpoint: The segment anything model checkpoint.
        input_folder: The folder with the livecell data.
        model_type: The type of the segment anything model.
        experiment_folder: The folder where to save all data associated with the experiment.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts.
        n_negatives: The number of negative point prompts.
        prompt_folder: The folder where the prompts should be saved.
        predictor: The segment anything predictor.
    """
    image_paths, gt_paths = _get_livecell_paths(input_folder)
    if predictor is None:
        predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)

    if use_boxes and use_points:
        assert (n_positives is not None) and (n_negatives is not None)
        setting_name = f"box/p{n_positives}-n{n_negatives}"
    elif use_boxes:
        setting_name = "box/p0-n0"
    elif use_points:
        assert (n_positives is not None) and (n_negatives is not None)
        setting_name = f"points/p{n_positives}-n{n_negatives}"
    else:
        raise ValueError("You need to use at least one of point or box prompts.")

    # we organize all folders with data from this experiment beneath 'experiment_folder'
    prediction_folder = os.path.join(experiment_folder, setting_name)  # where the predicted segmentations are saved
    os.makedirs(prediction_folder, exist_ok=True)
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)

    # NOTE: we can pass an external prompt folder, to re-use prompts from another experiment
    # for reproducibility / fair comparison of results
    if prompt_folder is None:
        prompt_folder = os.path.join(experiment_folder, "prompts")
        os.makedirs(prompt_folder, exist_ok=True)

    inference.run_inference_with_prompts(
        predictor,
        image_paths,
        gt_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        prompt_save_dir=prompt_folder,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=n_negatives,
    )
Run inference for livecell with a fixed prompt setting.
Arguments:
- checkpoint: The segment anything model checkpoint.
- input_folder: The folder with the livecell data.
- model_type: The type of the segment anything model.
- experiment_folder: The folder where to save all data associated with the experiment.
- use_points: Whether to use point prompts.
- use_boxes: Whether to use box prompts.
- n_positives: The number of positive point prompts.
- n_negatives: The number of negative point prompts.
- prompt_folder: The folder where the prompts should be saved.
- predictor: The segment anything predictor.
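A minimal usage sketch for a fixed-prompt run with a single positive point prompt; the checkpoint, data and experiment paths below are placeholders:

from micro_sam.evaluation.livecell import livecell_inference

livecell_inference(
    checkpoint="checkpoints/vit_b_lm.pt",    # placeholder checkpoint path
    input_folder="/data/livecell",           # root of the LIVECell download
    model_type="vit_b",
    experiment_folder="runs/livecell_vit_b",
    use_points=True,
    use_boxes=False,
    n_positives=1,    # required when use_points=True
    n_negatives=0,
)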
def run_livecell_precompute_embeddings(checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], n_val_per_cell_type: int = 25) -> None:
def run_livecell_precompute_embeddings(
    checkpoint: Union[str, os.PathLike],
    input_folder: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    n_val_per_cell_type: int = 25,
) -> None:
    """Run precomputation of val and test image embeddings for livecell.

    Args:
        checkpoint: The segment anything model checkpoint.
        input_folder: The folder with the livecell data.
        model_type: The type of the segment anything model.
        experiment_folder: The folder where to save all data associated with the experiment.
        n_val_per_cell_type: The number of validation images per cell type.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the embeddings will be saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)

    val_image_paths, _ = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type)
    test_image_paths, _ = _get_livecell_paths(input_folder, "test")

    precompute_all_embeddings(predictor, val_image_paths, embedding_folder)
    precompute_all_embeddings(predictor, test_image_paths, embedding_folder)
Run precomputation of val and test image embeddings for livecell.
Arguments:
- checkpoint: The segment anything model checkpoint.
- input_folder: The folder with the livecell data.
- model_type: The type of the segment anything model.
- experiment_folder: The folder where to save all data associated with the experiment.
- n_val_per_cell_type: The number of validation images per cell type.
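A sketch of precomputing the embeddings up front (paths are placeholders); the embeddings land in the 'embeddings' subfolder of the experiment folder and are re-used by the inference functions below:

from micro_sam.evaluation.livecell import run_livecell_precompute_embeddings

run_livecell_precompute_embeddings(
    checkpoint="checkpoints/vit_b_lm.pt",    # placeholder paths
    input_folder="/data/livecell",
    model_type="vit_b",
    experiment_folder="runs/livecell_vit_b",
    n_val_per_cell_type=25,
)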
def run_livecell_iterative_prompting(checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], start_with_box: bool = False, use_masks: bool = False) -> str:
def run_livecell_iterative_prompting(
    checkpoint: Union[str, os.PathLike],
    input_folder: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    start_with_box: bool = False,
    use_masks: bool = False,
) -> str:
    """Run inference on livecell with the iterative prompting setting.

    Args:
        checkpoint: The segment anything model checkpoint.
        input_folder: The folder with the livecell data.
        model_type: The type of the segment anything model.
        experiment_folder: The folder where to save all data associated with the experiment.
        start_with_box: Whether to use a bounding box or a single point as the first prompt.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Returns:
        The path where the predicted images are stored.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the embeddings will be saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)

    # where the predictions are saved
    prediction_folder = os.path.join(
        experiment_folder, "start_with_box" if start_with_box else "start_with_point"
    )

    image_paths, gt_paths = _get_livecell_paths(input_folder, "test")

    inference.run_inference_with_iterative_prompting(
        predictor=predictor,
        image_paths=image_paths,
        gt_paths=gt_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        start_with_box_prompt=start_with_box,
        use_masks=use_masks,
    )
    return prediction_folder
Run inference on livecell with the iterative prompting setting.
Arguments:
- checkpoint: The segment anything model checkpoint.
- input_folder: The folder with the livecell data.
- model_type: The type of the segment anything model.
- experiment_folder: The folder where to save all data associated with the experiment.
- start_with_box: Whether to use a bounding box or a single point as the first prompt.
- use_masks: Whether to make use of logits from previous prompt-based segmentation.

Returns:

The path where the predicted images are stored.
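A usage sketch with placeholder paths, starting the iterative prompting from box prompts:

from micro_sam.evaluation.livecell import run_livecell_iterative_prompting

prediction_folder = run_livecell_iterative_prompting(
    checkpoint="checkpoints/vit_b_lm.pt",    # placeholder paths
    input_folder="/data/livecell",
    model_type="vit_b",
    experiment_folder="runs/livecell_vit_b",
    start_with_box=True,    # predictions go to <experiment_folder>/start_with_box
    use_masks=False,
)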
def run_livecell_amg(checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, verbose_gs: bool = False, n_val_per_cell_type: int = 25) -> str:
def run_livecell_amg(
    checkpoint: Union[str, os.PathLike],
    input_folder: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    verbose_gs: bool = False,
    n_val_per_cell_type: int = 25,
) -> str:
    """Run automatic mask generation grid-search and inference for livecell.

    Args:
        checkpoint: The segment anything model checkpoint.
        input_folder: The folder with the livecell data.
        model_type: The type of the segment anything model.
        experiment_folder: The folder where to save all data associated with the experiment.
        iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        stability_score_values: The values for `stability_score_thresh` used in the gridsearch.
            By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        n_val_per_cell_type: The number of validation images per cell type.

    Returns:
        The path where the predicted images are stored.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
    amg = AutomaticMaskGenerator(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type)
    test_image_paths, _ = _get_livecell_paths(input_folder, "test")

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        amg, grid_search_values, val_image_paths, val_gt_paths, test_image_paths,
        embedding_folder, prediction_folder, gs_result_folder, verbose_gs=verbose_gs
    )
    return prediction_folder
Run automatic mask generation grid-search and inference for livecell.
Arguments:
- checkpoint: The segment anything model checkpoint.
- input_folder: The folder with the livecell data.
- model_type: The type of the segment anything model.
- experiment_folder: The folder where to save all data associated with the experiment.
- iou_thresh_values: The values for pred_iou_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
- stability_score_values: The values for stability_score_thresh used in the gridsearch. By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used.
- verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
- n_val_per_cell_type: The number of validation images per cell type.
Returns:
The path where the predicted images are stored.
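A sketch of an AMG run with placeholder paths; the optional value lists restrict the grid search to a smaller set than the defaults:

from micro_sam.evaluation.livecell import run_livecell_amg

prediction_folder = run_livecell_amg(
    checkpoint="checkpoints/vit_b_lm.pt",    # placeholder paths
    input_folder="/data/livecell",
    model_type="vit_b",
    experiment_folder="runs/livecell_vit_b",
    iou_thresh_values=[0.7, 0.8],            # optional: restrict the grid search
    stability_score_values=[0.8, 0.9],
)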
def run_livecell_instance_segmentation_with_decoder(checkpoint: Union[str, os.PathLike], input_folder: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], center_distance_threshold_values: Optional[List[float]] = None, boundary_distance_threshold_values: Optional[List[float]] = None, distance_smoothing_values: Optional[List[float]] = None, min_size_values: Optional[List[float]] = None, verbose_gs: bool = False, n_val_per_cell_type: int = 25) -> str:
def run_livecell_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    input_folder: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    center_distance_threshold_values: Optional[List[float]] = None,
    boundary_distance_threshold_values: Optional[List[float]] = None,
    distance_smoothing_values: Optional[List[float]] = None,
    min_size_values: Optional[List[float]] = None,
    verbose_gs: bool = False,
    n_val_per_cell_type: int = 25,
) -> str:
    """Run automatic instance segmentation grid-search and inference for livecell.

    Args:
        checkpoint: The segment anything model checkpoint.
        input_folder: The folder with the livecell data.
        model_type: The type of the segment anything model.
        experiment_folder: The folder where to save all data associated with the experiment.
        center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch.
            By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
        distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch.
            By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
        min_size_values: The values for `min_size` used in the gridsearch.
            By default the values 50, 100 and 200 are used.
        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
        n_val_per_cell_type: The number of validation images per cell type.

    Returns:
        The path where the predicted images are stored.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint)
    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type)
    test_image_paths, _ = _get_livecell_paths(input_folder, "test")

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder(
        center_distance_threshold_values=center_distance_threshold_values,
        boundary_distance_threshold_values=boundary_distance_threshold_values,
        distance_smoothing_values=distance_smoothing_values,
        min_size_values=min_size_values
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter, grid_search_values, val_image_paths, val_gt_paths, test_image_paths, embedding_dir=embedding_folder,
        prediction_dir=prediction_folder, result_dir=gs_result_folder, verbose_gs=verbose_gs
    )
    return prediction_folder
Run automatic instance segmentation grid-search and inference for livecell.
Arguments:
- checkpoint: The segment anything model checkpoint.
- input_folder: The folder with the livecell data.
- model_type: The type of the segment anything model.
- experiment_folder: The folder where to save all data associated with the experiment.
- center_distance_threshold_values: The values for center_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- boundary_distance_threshold_values: The values for boundary_distance_threshold used in the gridsearch. By default values in the range from 0.3 to 0.7 with a stepsize of 0.1 will be used.
- distance_smoothing_values: The values for distance_smoothing used in the gridsearch. By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used.
- min_size_values: The values for min_size used in the gridsearch. By default the values 50, 100 and 200 are used.
- verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
- n_val_per_cell_type: The number of validation images per cell type.
Returns:
The path where the predicted images are stored.
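A usage sketch with placeholder paths, assuming a checkpoint that also contains the decoder weights expected by get_predictor_and_decoder:

from micro_sam.evaluation.livecell import run_livecell_instance_segmentation_with_decoder

prediction_folder = run_livecell_instance_segmentation_with_decoder(
    checkpoint="checkpoints/vit_b_lm.pt",    # placeholder paths
    input_folder="/data/livecell",
    model_type="vit_b",
    experiment_folder="runs/livecell_vit_b",
    min_size_values=[100],    # optional: restrict the grid search
)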
def run_livecell_inference() -> None:
def run_livecell_inference() -> None:
    """Run LIVECell inference with command line tool."""
    parser = argparse.ArgumentParser()

    # the checkpoint, input and experiment folder
    parser.add_argument(
        "-c", "--ckpt", type=str, required=True,
        help="Provide model checkpoints (vanilla / finetuned)."
    )
    parser.add_argument(
        "-i", "--input", type=str, required=True,
        help="Provide the data directory for the LIVECell Dataset."
    )
    parser.add_argument(
        "-e", "--experiment_folder", type=str, required=True,
        help="Provide the path where all data for the inference run will be stored."
    )
    parser.add_argument(
        "-m", "--model", type=str, required=True,
        help="Pass the checkpoint-specific model name being used for inference."
    )

    # the experiment type:
    # 1. precompute image embeddings
    # 2. iterative prompting-based interactive instance segmentation (iterative prompting)
    #    - starting with point
    #    - starting with box
    # 3. automatic segmentation (auto)
    #    - automatic mask generation (amg)
    #    - automatic instance segmentation (ais)

    parser.add_argument("-p", "--precompute_embeddings", action="store_true")
    parser.add_argument("-ip", "--iterative_prompting", action="store_true")
    parser.add_argument("-amg", "--auto_mask_generation", action="store_true")
    parser.add_argument("-ais", "--auto_instance_segmentation", action="store_true")

    # the prompt settings for starting iterative prompting for interactive instance segmentation
    # (default: start with points)
    parser.add_argument(
        "-b", "--start_with_box", action="store_true", help="Start with box for iterative prompt-based segmentation."
    )
    parser.add_argument(
        "--use_masks", action="store_true",
        help="Whether to use logits from previous interactive segmentation as inputs for iterative prompting."
    )
    parser.add_argument(
        "--n_val_per_cell_type", default=25, type=int,
        help="How many validation samples per cell type to use for the grid search."
    )

    args = parser.parse_args()
    if sum([args.iterative_prompting, args.auto_mask_generation, args.auto_instance_segmentation]) > 1:
        warnings.warn(
            "It's recommended to choose only one of 'iterative_prompting', 'auto_mask_generation' and "
            "'auto_instance_segmentation' at a time, otherwise the run may take a long time."
        )

    if args.precompute_embeddings:
        run_livecell_precompute_embeddings(
            args.ckpt, args.input, args.model, args.experiment_folder, args.n_val_per_cell_type
        )

    if args.iterative_prompting:
        run_livecell_iterative_prompting(
            args.ckpt, args.input, args.model, args.experiment_folder,
            start_with_box=args.start_with_box, use_masks=args.use_masks
        )

    if args.auto_instance_segmentation:
        run_livecell_instance_segmentation_with_decoder(
            args.ckpt, args.input, args.model, args.experiment_folder, n_val_per_cell_type=args.n_val_per_cell_type
        )

    if args.auto_mask_generation:
        run_livecell_amg(
            args.ckpt, args.input, args.model, args.experiment_folder, n_val_per_cell_type=args.n_val_per_cell_type
        )
Run LIVECell inference with command line tool.
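The flags are parsed from the command line; one way to drive the same entry point from Python is to set sys.argv before calling it. A sketch with placeholder checkpoint and data paths:

import sys
from micro_sam.evaluation.livecell import run_livecell_inference

# equivalent to passing these flags on the command line; all paths are placeholders
sys.argv = [
    "livecell_inference",
    "-c", "checkpoints/vit_b_lm.pt",
    "-i", "/data/livecell",
    "-e", "runs/livecell_vit_b",
    "-m", "vit_b",
    "-p",      # precompute the embeddings
    "-ip",     # then run iterative prompting
]
run_livecell_inference()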
def run_livecell_evaluation() -> None:
def run_livecell_evaluation() -> None:
    """Run LIVECell evaluation with command line tool."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input", required=True, help="Provide the data directory for the LIVECell Dataset."
    )
    parser.add_argument(
        "-e", "--experiment_folder", required=True,
        help="Provide the path where the inference data is stored."
    )
    parser.add_argument(
        "-f", "--force", action="store_true",
        help="Force recomputation of already cached eval results."
    )
    args = parser.parse_args()

    _, gt_paths = _get_livecell_paths(args.input, "test")

    experiment_folder = args.experiment_folder
    save_root = os.path.join(experiment_folder, "results")

    inference_root_names = [
        "amg/inference", "instance_segmentation_with_decoder/inference", "start_with_box", "start_with_point"
    ]
    for inf_root in inference_root_names:
        pred_root = os.path.join(experiment_folder, inf_root)
        if not os.path.exists(pred_root):
            print(
                f"The predictions for '{inf_root}' were not generated.",
                "Please run the inference first to evaluate on the predictions."
            )
            continue

        if inf_root.startswith("start_with"):
            evaluation.run_evaluation_for_iterative_prompting(
                gt_paths=gt_paths,
                prediction_root=pred_root,
                experiment_folder=experiment_folder,
                start_with_box_prompt=(inf_root == "start_with_box"),
                overwrite_results=args.force
            )
        else:
            pred_paths = sorted(glob(os.path.join(pred_root, "*")))
            save_name = inf_root.split("/")[0]
            save_path = os.path.join(save_root, f"{save_name}.csv")

            if args.force and os.path.exists(save_path):
                os.remove(save_path)

            results = evaluation.run_evaluation(
                gt_paths=gt_paths,
                prediction_paths=pred_paths,
                save_path=save_path,
            )
            print(results)
Run LIVECell evaluation with command line tool.
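For the automatic segmentation settings the evaluation writes one CSV per setting to the 'results' subfolder of the experiment folder (amg.csv and instance_segmentation_with_decoder.csv). A small sketch for loading them afterwards, assuming pandas is installed and the corresponding predictions were evaluated; the experiment path is a placeholder:

import os
import pandas as pd

experiment_folder = "runs/livecell_vit_b"    # same folder as used for inference
results = pd.read_csv(os.path.join(experiment_folder, "results", "amg.csv"))
print(results)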