micro_sam.sam_annotator.image_series_annotator
import os
import time
from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets
from qtpy.QtCore import QTimer

from .. import util
from . import _widgets as widgets
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from .util import _sync_embedding_widget
from ..instance_segmentation import get_decoder
from ..precompute_state import _precompute_state_for_files


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    """Load the SAM model and, if `embedding_path` is given, precompute the embeddings for all images.

    Args:
        images: The images (file paths or arrays) of the series. Must not be empty.
        model_type: The Segment Anything model to use.
        embedding_path: Folder where the embeddings are saved, one zarr file per image.
            If None, nothing is precomputed.
        tile_shape: Shape of tiles for tiled embedding prediction.
        halo: Overlap between tiles.
        precompute_amg_state: Whether to also precompute the automatic mask generation state.
        checkpoint_path: Path to a custom checkpoint.
        device: The device to run the model on.
        ndim: The dimensionality of the data (2 or 3).
        prefer_decoder: Whether to load the instance segmentation decoder if the checkpoint has one.

    Returns:
        The predictor, the decoder (or None) and a list with one embedding path per image.
        The list contains only None entries if `embedding_path` is None.

    Raises:
        RuntimeError: If an expected embedding file was not created by the precomputation.
    """
    t_start = time.time()

    device = util.get_device(device)
    predictor, model_state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    if prefer_decoder and "decoder_state" in model_state:
        decoder = get_decoder(predictor.model.image_encoder, model_state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        # These file names must match the names used by `_precompute_state_for_files`:
        # arrays are named by index, files after their stem.
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        missing = [emb_path for emb_path in embedding_paths if not os.path.exists(emb_path)]
        if missing:
            raise RuntimeError(f"The precomputation did not create the expected embeddings: {missing}")

    t_run = time.time() - t_start
    # Split the rounded total seconds, so that we never print a value like 'xx:60'.
    minutes, seconds = divmod(int(round(t_run)), 60)
    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths


def _get_input_shape(image, is_volumetric=False):
    """Return the spatial shape of `image`, i.e. its shape without a trailing channel axis.

    2d images keep their shape. 3d images are treated as volumes if `is_volumetric`, otherwise as
    2d images with channels. 4d images are treated as volumes with channels.

    Raises:
        ValueError: If the image does not have 2, 3 or 4 dimensions.
    """
    if image.ndim == 2:
        return image.shape
    if image.ndim == 3:
        return image.shape if is_volumetric else image.shape[:-1]
    if image.ndim == 4:
        return image.shape[:-1]
    raise ValueError(f"Invalid image dimensions, expect 2, 3 or 4 dimensions, got {image.ndim}")


def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device, embedding_path,
    segmentation_path, initial_seg,
):
    """Set up the viewer, annotator state and annotator widget for the first image of the series.

    An existing segmentation at `segmentation_path` takes precedence over `initial_seg`.
    `initial_seg` may be an array, a file path or None.

    Returns:
        The napari viewer and the annotator widget.

    Raises:
        ValueError: If the image dimensions are not supported by the selected annotator.
    """
    # Validate the dimensions first, before creating a viewer or computing embeddings.
    if is_volumetric:
        if image.ndim not in (3, 4):
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
    elif image.ndim not in (2, 3):
        raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")

    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False, use_cli=True,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    if is_volumetric:
        annotator = Annotator3d(viewer, reset_state=False)
    else:
        annotator = Annotator2d(viewer, reset_state=False)

    if os.path.exists(segmentation_path):
        segmentation_result = imageio.imread(segmentation_path)
    else:
        segmentation_result = None
    if initial_seg is not None and segmentation_result is None:
        segmentation_result = initial_seg if isinstance(initial_seg, np.ndarray) else imageio.imread(initial_seg)
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )
    return viewer, annotator


def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    initial_segmentations: Optional[Union[List[Union[os.PathLike, str]], List[np.ndarray]]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Folder where the precomputed embeddings are saved, one zarr file per image.
            If not given, the embeddings are not stored and have to be recomputed every time.
        initial_segmentations: Initial segmentations to be corrected.
            By default no initial segmentations are loaded.
            If given, the initial segmentations will be loaded into 'committed_objects'.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
        device: The device to use for the model.
            By default, the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no images are given, or if the number of initial segmentations
            does not match the number of images.
    """
    if len(images) == 0:
        raise ValueError("You have not passed any images to annotate.")
    if initial_segmentations is not None and len(initial_segmentations) != len(images):
        raise ValueError(
            "You have passed initial segmentations, but the number of images and segmentations is not the same: "
            f"{len(images)} != {len(initial_segmentations)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        """Return the path where the segmentation for the image at `current_idx` is saved."""
        # Arrays have no file name, so their segmentation is named after the index.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.splitext(os.path.basename(image_path))[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _is_segmented(image_id):
        """Check whether a segmentation was already saved for the image at `image_id`."""
        return os.path.exists(_get_save_path(images[image_id], image_id))

    # Determine the first image to annotate. This happens before the (expensive) precomputation,
    # so that we can stop early if there is nothing left to do.
    next_image_id = 0
    if skip_segmented:
        while next_image_id < len(images) and _is_segmented(next_image_id):
            next_image_id += 1
        if next_image_id == len(images):
            print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
            return
        print("The first image to annotate is image number", next_image_id)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    def _load_image(image_id):
        """Return the image at `image_id` (read from disk if needed) and its embedding path."""
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        return image, embedding_paths[image_id]

    save_path = _get_save_path(images[next_image_id], next_image_id)
    image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path, None if initial_segmentations is None else initial_segmentations[next_image_id],
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        """Write the segmentation for the image at `current_idx` to the output folder."""
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _close_if_confirmed():
        """Ask whether napari should be closed after the last image, and close it if confirmed."""
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            # Schedule the close on the Qt event loop instead of closing from within this callback.
            QTimer.singleShot(0, viewer.close)

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        if not np.any(segmentation):
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Find the next image, skipping images that are already segmented if requested.
        candidate_id = next_image_id + 1
        if skip_segmented:
            while candidate_id < len(images) and _is_segmented(candidate_id):
                candidate_id += 1

        # Check if we are done. The current image and index are kept unchanged, so pressing
        # the button again after declining to close napari is safe.
        if candidate_id == len(images):
            _close_if_confirmed()
            return

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image.
        next_image_id = candidate_id

        # Load an existing segmentation. This is only possible if skip_segmented is False,
        # otherwise the image would have been skipped.
        save_path = _get_save_path(images[next_image_id], next_image_id)
        if os.path.exists(save_path):
            segmentation_result = imageio.imread(save_path)
        else:
            segmentation_result = None

        # Load initial segmentation if it exists and if we don't have a segmentation result loaded yet.
        if initial_segmentations is not None and segmentation_result is None:
            initial_seg = initial_segmentations[next_image_id]
            segmentation_result = initial_seg if isinstance(initial_seg, np.ndarray) else imageio.imread(initial_seg)

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image, image_embedding_path = _load_image(next_image_id)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()


def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    initial_segmentation_folder: Optional[str] = None,
    initial_segmentation_pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The 2d annotator is used by default. Pass `is_volumetric=True` via `kwargs` to use the 3d annotator.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        initial_segmentation_folder: A folder with initial segmentation results.
            By default no initial segmentations are loaded.
        initial_segmentation_pattern: The glob pattern for loading files from `initial_segmentation_folder`.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    # Sort the file lists, so that images and initial segmentations are matched by their order.
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    if initial_segmentation_folder is None:
        initial_segmentations = None
    else:
        initial_segmentations = sorted(glob(os.path.join(
            initial_segmentation_folder, initial_segmentation_pattern
        )))

    return image_series_annotator(
        image_files, output_folder,
        initial_segmentations=initial_segmentations,
        viewer=viewer, return_viewer=return_viewer, **kwargs
    )


class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget for configuring and starting the annotation of all images in a folder."""

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button to trigger the embedding computation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the input folder, output folder and model family widgets."""
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        # Add the model family widget section.
        layout = self._create_model_section(create_layout=False)
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Create the collapsible 'Advanced Settings' section and return it."""
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Show an error message if the input or output folder is missing.

        Returns:
            The result of the error message if inputs are invalid, otherwise False.
        """
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Validate the settings and start the image folder annotator in this widget's viewer."""
        self._validate_model_type_and_custom_weights()

        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "--initial_segmentation_folder",
        help="A folder with initial segmentation results. By default no initial segmentations are loaded."
    )
    # The default must be a valid glob pattern: None would fail in 'os.path.join'.
    parser.add_argument(
        "--initial_segmentation_pattern", default="*",
        help="The glob pattern for loading files from `initial_segmentation_folder`. "
        "By default all files in the folder will be loaded."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation. "
        "This takes more time during precomputation, but makes automatic mask generation faster."
    )
    # NOTE: the next two flags use 'store_false': passing them DISABLES the respective option.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to NOT use decoder-based instance segmentation, even if the model has a decoder. "
        "By default the decoder is used if it is available."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to NOT skip images that were already segmented. "
        "Existing segmentations will then be loaded for further correction."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        initial_segmentation_folder=args.initial_segmentation_folder,
        initial_segmentation_pattern=args.initial_segmentation_pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def
image_series_annotator( images: Union[List[Union[os.PathLike, str]], List[numpy.ndarray]], output_folder: str, model_type: str = 'vit_b_lm', embedding_path: Optional[str] = None, initial_segmentations: Union[List[Union[os.PathLike, str]], List[numpy.ndarray], NoneType] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, is_volumetric: bool = False, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True, skip_segmented: bool = True) -> Optional[napari.viewer.Viewer]:
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    initial_segmentations: Optional[Union[List[Union[os.PathLike, str]], List[np.ndarray]]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Folder where the precomputed embeddings are saved, one zarr file per image.
            If not given, the embeddings are not stored and have to be recomputed every time.
        initial_segmentations: Initial segmentations to be corrected.
            By default no initial segmentations are loaded.
            If given, the initial segmentations will be loaded into 'committed_objects'.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
        device: The device to use for the model.
            By default, the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no images are given, or if the number of initial segmentations
            does not match the number of images.
    """
    if len(images) == 0:
        raise ValueError("You have not passed any images to annotate.")
    if initial_segmentations is not None and len(initial_segmentations) != len(images):
        raise ValueError(
            "You have passed initial segmentations, but the number of images and segmentations is not the same: "
            f"{len(images)} != {len(initial_segmentations)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        """Return the path where the segmentation for the image at `current_idx` is saved."""
        # Arrays have no file name, so their segmentation is named after the index.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.splitext(os.path.basename(image_path))[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _is_segmented(image_id):
        """Check whether a segmentation was already saved for the image at `image_id`."""
        return os.path.exists(_get_save_path(images[image_id], image_id))

    # Determine the first image to annotate. This happens before the (expensive) precomputation,
    # so that we can stop early if there is nothing left to do.
    next_image_id = 0
    if skip_segmented:
        while next_image_id < len(images) and _is_segmented(next_image_id):
            next_image_id += 1
        if next_image_id == len(images):
            print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
            return
        print("The first image to annotate is image number", next_image_id)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    def _load_image(image_id):
        """Return the image at `image_id` (read from disk if needed) and its embedding path."""
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        return image, embedding_paths[image_id]

    save_path = _get_save_path(images[next_image_id], next_image_id)
    image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path, None if initial_segmentations is None else initial_segmentations[next_image_id],
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        """Write the segmentation for the image at `current_idx` to the output folder."""
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _close_if_confirmed():
        """Ask whether napari should be closed after the last image, and close it if confirmed."""
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            # Schedule the close on the Qt event loop instead of closing from within this callback.
            QTimer.singleShot(0, viewer.close)

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        if not np.any(segmentation):
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Find the next image, skipping images that are already segmented if requested.
        candidate_id = next_image_id + 1
        if skip_segmented:
            while candidate_id < len(images) and _is_segmented(candidate_id):
                candidate_id += 1

        # Check if we are done. The current image and index are kept unchanged, so pressing
        # the button again after declining to close napari is safe.
        if candidate_id == len(images):
            _close_if_confirmed()
            return

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image.
        next_image_id = candidate_id

        # Load an existing segmentation. This is only possible if skip_segmented is False,
        # otherwise the image would have been skipped.
        save_path = _get_save_path(images[next_image_id], next_image_id)
        if os.path.exists(save_path):
            segmentation_result = imageio.imread(save_path)
        else:
            segmentation_result = None

        # Load initial segmentation if it exists and if we don't have a segmentation result loaded yet.
        if initial_segmentations is not None and segmentation_result is None:
            initial_seg = initial_segmentations[next_image_id]
            segmentation_result = initial_seg if isinstance(initial_seg, np.ndarray) else imageio.imread(initial_seg)

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image, image_embedding_path = _load_image(next_image_id)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
Run the annotation tool for a series of images (supported for both 2d and 3d images).
Arguments:
- images: List of the file paths or list of (set of) slices for the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- embedding_path: Folder where the embeddings are saved (one `.zarr` file is written per image).
- initial_segmentations: Initial segmentations to be corrected. By default no initial segmentations are loaded. If given, the initial segmentations will be loaded into 'committed_objects'.
- tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
- skip_segmented: Whether to skip images that were already segmented. If set to False, then segmentations that already exist will be loaded and used to populate the 'committed_objects' layer.
Returns:
The napari viewer, only returned if `return_viewer=True`.
def image_folder_annotator(input_folder: str, output_folder: str, pattern: str = '*', initial_segmentation_folder: Optional[str] = None, initial_segmentation_pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    initial_segmentation_folder: Optional[str] = None,
    initial_segmentation_pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Annotate all images found in a folder with the image series annotator.

    The images (and, optionally, initial segmentations) are discovered with glob patterns
    and handed to `image_series_annotator`. Both file lists are sorted, and the i-th
    initial segmentation is used for the i-th image. Pairing is purely positional, so
    both folders should yield matching files in the same sorted order.

    Args:
        input_folder: The folder containing the images to annotate.
        output_folder: The folder the segmentation results are written to.
        pattern: Glob pattern selecting files inside `input_folder`. Matches all files by default.
        initial_segmentation_folder: Optional folder with initial segmentations that are loaded
            for correction. If not given, no initial segmentations are used.
        initial_segmentation_pattern: Glob pattern selecting files inside `initial_segmentation_folder`.
        viewer: An optional, already created napari viewer to attach the tool to.
        return_viewer: If True, return the viewer instead of starting the napari event loop.
        kwargs: Further keyword arguments forwarded unchanged to
            `micro_sam.sam_annotator.image_series_annotator` (e.g. `model_type`,
            `embedding_path`, `tile_shape`, `halo`, `is_volumetric`).

    Returns:
        The napari viewer if `return_viewer=True`, otherwise whatever
        `image_series_annotator` returns.
    """
    def _sorted_files(folder, glob_pattern):
        # glob gives no ordering guarantee; sorting makes the order deterministic,
        # which matters because images and initial segmentations are paired by index.
        return sorted(glob(os.path.join(folder, glob_pattern)))

    image_files = _sorted_files(input_folder, pattern)

    initial_segmentations = None
    if initial_segmentation_folder is not None:
        initial_segmentations = _sorted_files(initial_segmentation_folder, initial_segmentation_pattern)

    return image_series_annotator(
        image_files,
        output_folder,
        initial_segmentations=initial_segmentations,
        viewer=viewer,
        return_viewer=return_viewer,
        **kwargs
    )
Run the annotation tool for a series of images in a folder (2d by default; pass `is_volumetric=True` via `kwargs` to use the 3d annotator).
Arguments:
- input_folder: The folder with the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- pattern: The glob pattern for loading files from `input_folder`. By default all files will be loaded.
- initial_segmentation_folder: A folder with initial segmentation results. By default no initial segmentations are loaded.
- initial_segmentation_pattern: The glob pattern for loading files from `initial_segmentation_folder`.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.
Returns:
The napari viewer, only returned if `return_viewer=True`.
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget for annotating all images in a folder.

    The widget collects an input folder, an output folder, the model choice and some
    advanced settings. When the "Annotate Images" button is pressed, it launches
    `image_folder_annotator` in the viewer passed to the constructor.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        """Build the widget UI.

        Args:
            viewer: The napari viewer in which the annotation tool is started.
            parent: Optional parent Qt widget, forwarded to the base widget.
        """
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button; clicking it calls `__call__`, which starts the annotation.
        # NOTE(review): Qt's `clicked` signal carries a `checked` bool; depending on the
        # binding it may be forwarded as `skip_validate` (False for a non-checkable button) — confirm.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the always-visible options: input folder, output folder and model family."""
        # The attributes are initialized here and passed to the `_add_*_param` helpers by name;
        # presumably those helpers update `self.<name>` when the user edits the field (defined in _widgets).
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        # Add the model family widget section.
        layout = self._create_model_section(create_layout=False)
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Build the collapsible "Advanced Settings" section and return it.

        It contains the model size, file pattern, 2d/3d switch, device, embedding
        folder, custom weights and tiling parameters.
        """
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        # Glob pattern used to select the images inside `self.folder`.
        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        # Tile and halo default to 0; presumably `widgets._process_tiling_inputs` treats 0
        # as "no tiling" — confirm in _widgets.
        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Check that an input folder with matching files and an output folder are set.

        Returns:
            False if the inputs are valid. Otherwise, shows an error message via
            `widgets._generate_message("error", ...)` and returns its result; `__call__`
            aborts when the returned value is truthy.
        """
        # The input folder counts as missing if it is unset or `self.pattern` matches nothing in it.
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Start the image series annotation with the current widget settings.

        Args:
            skip_validate: If True, skip the input-folder / output-folder checks.
        """
        # The return value of the model/weights validation is not used here.
        self._validate_model_type_and_custom_weights()

        if not skip_validate and self._validate_inputs():
            return
        # Convert the tile/halo widget values into the `tile_shape` / `halo` arguments.
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        # `return_viewer=True` makes the annotator return the viewer instead of calling
        # `napari.run()`; presumably because this widget already lives inside a running viewer.
        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )
QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
def __init__(self, viewer: napari.Viewer, parent=None):
    """Build the image series annotator widget.

    Args:
        viewer: The napari viewer in which the annotation tool is started.
        parent: Optional parent Qt widget, forwarded to the base widget.
    """
    super().__init__(parent=parent)
    self._viewer = viewer

    # Create the UI: the general options.
    self._create_options()

    # Add the settings (collapsible).
    self.layout().addWidget(self._create_settings())

    # Add the run button; clicking it calls `__call__`, which starts the annotation.
    # NOTE(review): Qt's `clicked` signal carries a `checked` bool; depending on the
    # binding it may be forwarded as `skip_validate` (False for a non-checkable button) — confirm.
    self.run_button = QtWidgets.QPushButton("Annotate Images")
    self.run_button.clicked.connect(self.__call__)
    self.layout().addWidget(self.run_button)
Inherited Members
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- changeEvent
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuEvent
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- effectiveWinId
- ensurePolished
- enterEvent
- event
- find
- focusInEvent
- focusNextChild
- focusNextPrevChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyPressEvent
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumSizeHint
- minimumWidth
- mouseDoubleClickEvent
- mouseGrabber
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- paintEvent
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- resizeEvent
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeHint
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- wheelEvent
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- eventFilter
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM