micro_sam.sam_annotator.image_series_annotator
```python
import os
import time
from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets

from .. import util
from . import _widgets as widgets
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from .util import _sync_embedding_widget
from ..instance_segmentation import get_decoder
from ..precompute_state import _precompute_state_for_files


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    t_start = time.time()

    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    if prefer_decoder and "decoder_state" in state:
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i, path in enumerate(images)
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)

    t_run = time.time() - t_start
    minutes = int(t_run // 60)
    seconds = int(round(t_run % 60, 0))
    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths


def _get_input_shape(image, is_volumetric=False):
    if image.ndim == 2:
        image_shape = image.shape
    elif image.ndim == 3:
        if is_volumetric:
            image_shape = image.shape
        else:
            image_shape = image.shape[:-1]
    elif image.ndim == 4:
        image_shape = image.shape[:-1]
    else:
        raise ValueError(f"Invalid image dimensions, expected 2, 3 or 4 dimensions, got {image.ndim}")

    return image_shape


def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device, embedding_path,
    segmentation_path,
):
    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    if is_volumetric:
        if image.ndim not in [3, 4]:
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
        annotator = Annotator3d(viewer)
    else:
        if image.ndim not in (2, 3):
            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
        annotator = Annotator2d(viewer)

    if os.path.exists(segmentation_path):
        segmentation_result = imageio.imread(segmentation_path)
    else:
        segmentation_result = None
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )
    return viewer, annotator


def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    next_image_id = 0
    have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray)

    def _get_save_path(image_path, current_idx):
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _load_image(image_id):
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        image_embedding_path = embedding_paths[image_id]
        return image, image_embedding_path

    # Check which image to load next if we skip segmented images.
    if skip_segmented:
        while True:
            if next_image_id == len(images):
                print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
                return

            save_path = _get_save_path(images[next_image_id], next_image_id)
            if not os.path.exists(save_path):
                print("The first image to annotate is image number", next_image_id)
                image, image_embedding_path = _load_image(next_image_id)
                break

            next_image_id += 1

    else:
        save_path = _get_save_path(images[next_image_id], next_image_id)
        image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        abort = False
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
        if abort:
            return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image.
        next_image_id += 1

        # Check if we are done.
        if next_image_id == len(images):
            # Inform the user via dialog.
            abort = widgets._generate_message("info", end_msg)
            if not abort:
                viewer.close()
            return

        # If we are skipping images that are already segmented, then check if we have to load the next image.
        save_path = _get_save_path(images[next_image_id], next_image_id)
        if skip_segmented:
            segmentation_result = None
            while os.path.exists(save_path):
                next_image_id += 1

                # Check if we are done.
                if next_image_id == len(images):
                    # Inform the user via dialog.
                    abort = widgets._generate_message("info", end_msg)
                    if not abort:
                        viewer.close()
                    return

                save_path = _get_save_path(images[next_image_id], next_image_id)
        else:
            if os.path.exists(save_path):
                segmentation_result = imageio.imread(save_path)
            else:
                segmentation_result = None

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        if have_inputs_as_arrays:
            image = images[next_image_id]
        else:
            image = imageio.imread(images[next_image_id])

        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()


def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the 2d annotation tool for a series of images in a folder.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )


class ImageSeriesAnnotator(widgets._WidgetBase):
    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button to trigger the embedding computation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        # Add the model family widget section.
        layout = self._create_model_section(create_layout=False)
        self.layout().addLayout(layout)

    def _create_settings(self):
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        self._validate_model_type_and_custom_weights()

        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument("--precompute_amg_state", action="store_true")
    parser.add_argument("--prefer_decoder", action="store_false")
    parser.add_argument("--skip_segmented", action="store_false")

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
```
```python
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = "vit_b_lm",
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional[napari.viewer.Viewer] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional[napari.viewer.Viewer]:
```
Run the annotation tool for a series of images (supported for both 2d and 3d images).
Arguments:
- images: List of the file paths or list of (set of) slices for the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- embedding_path: Filepath where to save the embeddings.
- tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- is_volumetric: Whether to use the 3d annotator.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
- skip_segmented: Whether to skip images that were already segmented. If set to False, then segmentations that already exist will be loaded and used to populate the 'committed_objects' layer.
Returns:
The napari viewer, only returned if `return_viewer=True`.
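For instance, the call below annotates a series of images and caches the embeddings between sessions. This is a minimal sketch; the file and folder names are placeholders:

```python
from micro_sam.sam_annotator.image_series_annotator import image_series_annotator

# Hypothetical input files; a list of numpy arrays works as well.
images = ["cells_01.tif", "cells_02.tif", "cells_03.tif"]

image_series_annotator(
    images,
    output_folder="./segmentations",  # one tif with the committed objects per image
    embedding_path="./embeddings",    # cache embeddings as zarr instead of recomputing them
    tile_shape=(1024, 1024), halo=(256, 256),  # optional: tiled embedding prediction
)
```

Passing `tile_shape` and `halo` switches to tiled embedding prediction, which helps for images too large to pass to Segment Anything at once.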
```python
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional[napari.viewer.Viewer] = None,
    return_viewer: bool = False,
    **kwargs,
) -> Optional[napari.viewer.Viewer]:
```
Run the 2d annotation tool for a series of images in a folder.
Arguments:
- input_folder: The folder with the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- pattern: The glob pattern for loading files from `input_folder`. By default all files will be loaded.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.
Returns:
The napari viewer, only returned if `return_viewer=True`.
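A minimal sketch of calling it, where the folder paths and the tif pattern are placeholders; any extra keyword arguments such as `embedding_path` are forwarded to `image_series_annotator`:

```python
from micro_sam.sam_annotator.image_series_annotator import image_folder_annotator

image_folder_annotator(
    input_folder="./data",            # hypothetical folder with the images
    output_folder="./segmentations",
    pattern="*.tif",                  # annotate all tifs in the folder
    embedding_path="./embeddings",    # forwarded to image_series_annotator
)
```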
```python
class ImageSeriesAnnotator(widgets._WidgetBase):
    def __init__(self, viewer: napari.Viewer, parent=None): ...
```
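The widget collects the input and output folders, the model choice, and the advanced settings (pattern, device, tiling, custom weights) and, on `__call__`, forwards them to `image_folder_annotator` (see the class source above). A minimal sketch of adding it manually to a viewer as a dock widget:

```python
import napari

from micro_sam.sam_annotator.image_series_annotator import ImageSeriesAnnotator

viewer = napari.Viewer()
# Build the widget for this viewer and dock it; clicking "Annotate Images"
# then runs image_folder_annotator with the settings chosen in the UI.
widget = ImageSeriesAnnotator(viewer)
viewer.window.add_dock_widget(widget, name="Image Series Annotator")
napari.run()
```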