micro_sam.sam_annotator.image_series_annotator
import os
import time
from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets

from .. import util
from . import _widgets as widgets
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from .util import _sync_embedding_widget
from ..instance_segmentation import get_decoder
from ..precompute_state import _precompute_state_for_files


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    """Load the SAM predictor (and decoder, if available) and precompute the embeddings for all images.

    Args:
        images: The image file paths or the image arrays.
        model_type: The Segment Anything model to use.
        embedding_path: Folder for the precomputed embeddings. If None, nothing is precomputed.
        tile_shape: Shape of tiles for tiled embedding prediction.
        halo: Overlap between tiles.
        precompute_amg_state: Whether to also precompute the automatic mask generation state.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The device for the predictor.
        ndim: The dimensionality of the data (2 or 3).
        prefer_decoder: Whether to use the instance segmentation decoder if the checkpoint contains one.

    Returns:
        The predictor, the decoder (None if not used) and one embedding path per image
        (all None if `embedding_path` is None).
    """
    t_start = time.time()

    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    # The decoder is only available if the loaded checkpoint state contains its weights.
    if prefer_decoder and "decoder_state" in state:
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        # These file names are expected to match the ones written by `_precompute_state_for_files`;
        # the existence check below guards this assumption.
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)

    t_run = time.time() - t_start
    minutes = int(t_run // 60)
    seconds = int(round(t_run % 60, 0))
    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths


def _get_input_shape(image, is_volumetric=False):
    """Return the spatial shape of `image`, dropping a trailing channel axis if present.

    A 3d input is treated as a volume if `is_volumetric` is set and as a 2d image with
    channels otherwise; a 4d input is always treated as having a trailing channel axis.

    Raises:
        ValueError: If the image does not have 2, 3 or 4 dimensions.
    """
    if image.ndim == 2:
        return image.shape
    if image.ndim == 3:
        return image.shape if is_volumetric else image.shape[:-1]
    if image.ndim == 4:
        return image.shape[:-1]
    # Previously this fell through to an UnboundLocalError.
    raise ValueError(f"Invalid image dimensions, expect 2, 3 or 4 dimensions, got {image.ndim}")


def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device, embedding_path,
    segmentation_path,
):
    """Set up the viewer, the annotator state and the 2d or 3d annotator for the first image.

    Args:
        viewer: An existing napari viewer, or None to create a new one.
        image: The first image to annotate.
        image_embedding_path: The embedding path for this image (may be None).
        segmentation_path: Path of an existing segmentation to load into the annotator, if the file exists.
        The remaining arguments are forwarded to the predictor initialization and the embedding widget.

    Returns:
        The napari viewer and the annotator widget.

    Raises:
        ValueError: If the image dimensionality does not fit the selected (2d or 3d) annotator.
    """
    # Validate the dimensions before creating a viewer or initializing the predictor state.
    if is_volumetric:
        if image.ndim not in (3, 4):
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
    elif image.ndim not in (2, 3):
        raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")

    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False, use_cli=True,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    annotator = Annotator3d(viewer) if is_volumetric else Annotator2d(viewer)

    segmentation_result = imageio.imread(segmentation_path) if os.path.exists(segmentation_path) else None
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )
    return viewer, annotator


def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
        device: The device to use for the predictor. By default the most performant available device is used.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("No images were given to annotate.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        """Return the segmentation file path for the image at `current_idx`."""
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.splitext(os.path.basename(image_path))[0] + ".tif"
        return os.path.join(output_folder, fname)

    # Find the first image to annotate. This happens before the precomputation,
    # so that the model is not loaded when all images are already segmented.
    next_image_id = 0
    if skip_segmented:
        while os.path.exists(_get_save_path(images[next_image_id], next_image_id)):
            next_image_id += 1
            if next_image_id == len(images):
                print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
                return
        print("The first image to annotate is image number", next_image_id)
    save_path = _get_save_path(images[next_image_id], next_image_id)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    def _load_image(image_id):
        """Return the image at `image_id` (read from disk for file inputs) and its embedding path."""
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        return image, embedding_paths[image_id]

    image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        """Write `segmentation` for the image at `current_idx` to the output folder."""
        out_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(out_path, segmentation, compression="zlib")

    def _ask_to_close():
        """Tell the user that the last image is done and close napari unless they abort."""
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        if not segmentation.any():
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            if widgets._generate_message("info", msg):
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(segmentation)

        # Go to the next image and check if we are done.
        next_image_id += 1
        if next_image_id == len(images):
            _ask_to_close()
            return

        save_path = _get_save_path(images[next_image_id], next_image_id)
        if skip_segmented:
            # Advance past all images that already have a saved segmentation.
            segmentation_result = None
            while os.path.exists(save_path):
                next_image_id += 1
                if next_image_id == len(images):
                    _ask_to_close()
                    return
                save_path = _get_save_path(images[next_image_id], next_image_id)
        else:
            # Load an existing segmentation so that it can be refined.
            segmentation_result = imageio.imread(save_path) if os.path.exists(save_path) else None

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image, image_embedding_path = _load_image(next_image_id)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()


def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The 2d annotator is used by default; pass `is_volumetric=True` via `kwargs` to annotate 3d volumes.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    if not image_files:
        raise ValueError(f"No files matching the pattern '{pattern}' were found in '{input_folder}'.")

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )


class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget for starting the image series annotator on a folder of images from within napari."""

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button to trigger the embedding computation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the input folder, output folder and model family widgets."""
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        # Add the model family widget section.
        layout = self._create_model_section(create_layout=False)
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Create the collapsible 'Advanced Settings' section and return it."""
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Show an error message if the input folder is missing/empty or no output folder is set.

        Returns the result of the error dialog if inputs are invalid (treated as 'abort' by the caller),
        otherwise False.
        """
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Start the image folder annotator with the settings selected in this widget."""
        self._validate_model_type_and_custom_weights()

        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation."
    )
    # NOTE: the two flags below use 'store_false': passing them DISABLES the respective behavior.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to disable decoder based instance segmentation. "
        "By default the decoder is used if the model has one."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to not skip images that were already segmented. "
        "Their existing segmentations are then loaded into the 'committed_objects' layer."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def
image_series_annotator( images: Union[List[Union[str, os.PathLike]], List[numpy.ndarray]], output_folder: str, model_type: str = 'vit_b_lm', embedding_path: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, is_volumetric: bool = False, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True, skip_segmented: bool = True) -> Optional[napari.viewer.Viewer]:
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
        device: The device to use for the predictor. By default the most performant available device is used.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("No images were given to annotate.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        """Return the segmentation file path for the image at `current_idx`."""
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.splitext(os.path.basename(image_path))[0] + ".tif"
        return os.path.join(output_folder, fname)

    # Find the first image to annotate. This happens before the precomputation,
    # so that the model is not loaded when all images are already segmented.
    next_image_id = 0
    if skip_segmented:
        while os.path.exists(_get_save_path(images[next_image_id], next_image_id)):
            next_image_id += 1
            if next_image_id == len(images):
                print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
                return
        print("The first image to annotate is image number", next_image_id)
    save_path = _get_save_path(images[next_image_id], next_image_id)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    def _load_image(image_id):
        """Return the image at `image_id` (read from disk for file inputs) and its embedding path."""
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        return image, embedding_paths[image_id]

    image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        """Write `segmentation` for the image at `current_idx` to the output folder."""
        out_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(out_path, segmentation, compression="zlib")

    def _ask_to_close():
        """Tell the user that the last image is done and close napari unless they abort."""
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        if not segmentation.any():
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            if widgets._generate_message("info", msg):
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(segmentation)

        # Go to the next image and check if we are done.
        next_image_id += 1
        if next_image_id == len(images):
            _ask_to_close()
            return

        save_path = _get_save_path(images[next_image_id], next_image_id)
        if skip_segmented:
            # Advance past all images that already have a saved segmentation.
            segmentation_result = None
            while os.path.exists(save_path):
                next_image_id += 1
                if next_image_id == len(images):
                    _ask_to_close()
                    return
                save_path = _get_save_path(images[next_image_id], next_image_id)
        else:
            # Load an existing segmentation so that it can be refined.
            segmentation_result = imageio.imread(save_path) if os.path.exists(save_path) else None

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image, image_embedding_path = _load_image(next_image_id)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
Run the annotation tool for a series of images (supported for both 2d and 3d images).
Arguments:
- images: List of the file paths or list of (set of) slices for the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- embedding_path: Filepath where to save the embeddings.
- tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
- skip_segmented: Whether to skip images that were already segmented. If set to False, then segmentations that already exist will be loaded and used to populate the 'committed_objects' layer.
Returns:
The napari viewer, only returned if `return_viewer=True`.
def
image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The 2d annotator is used by default; pass `is_volumetric=True` via `kwargs` to annotate 3d volumes.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    # Fail early with a clear message instead of an IndexError deep inside the annotator.
    if not image_files:
        raise ValueError(f"No files matching the pattern '{pattern}' were found in '{input_folder}'.")

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )
Run the 2d annotation tool for a series of images in a folder.
Arguments:
- input_folder: The folder with the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- pattern: The glob pattern for loading files from `input_folder`. By default all files will be loaded.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.
Returns:
The napari viewer, only returned if `return_viewer=True`.
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget that starts the annotation of all images in a folder.

    It collects the input and output folders, the model choice and a set of advanced
    settings (file pattern, 3d mode, device, embeddings path, custom weights, tiling and
    halo). Pressing the "Annotate Images" button forwards these to `image_folder_annotator`.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        """Build the widget UI.

        Args:
            viewer: The napari viewer in which the annotation tool is started.
            parent: Optional parent Qt widget.
        """
        super().__init__(parent=parent)
        self._viewer = viewer

        # General options: input/output folders and the model family.
        self._create_options()
        root = self.layout()

        # Collapsible panel with the advanced settings.
        root.addWidget(self._create_settings())

        # Button that launches the annotator with the current settings.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        root.addWidget(self.run_button)

    def _create_options(self):
        """Add the input folder, output folder and model family selectors to the main layout."""
        self.folder = None
        _, folder_row = self._add_path_param(
            "folder",
            self.folder,
            "directory",
            title="Input Folder",
            placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder"),
        )
        self.layout().addLayout(folder_row)

        self.output_folder = None
        _, output_row = self._add_path_param(
            "output_folder",
            self.output_folder,
            "directory",
            title="Output Folder",
            placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder"),
        )
        self.layout().addLayout(output_row)

        # The model family is chosen here; the model size lives in the advanced settings.
        model_row = self._create_model_section(create_layout=False)
        self.layout().addLayout(model_row)

    def _create_settings(self):
        """Build the collapsible "Advanced Settings" panel.

        Each backing attribute (pattern, is_volumetric, device, embeddings_save_path,
        custom_weights, tile/halo sizes) is initialized before its widget is created.

        Returns:
            The collapsible widget produced by `widgets._make_collapsible`.
        """
        panel = QtWidgets.QWidget()
        column = QtWidgets.QVBoxLayout()
        panel.setLayout(column)

        # Model size selection.
        column.addLayout(self._create_model_size_section())

        # Glob pattern selecting the files inside the input folder.
        self.pattern = "*"
        _, pattern_row = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        column.addLayout(pattern_row)

        # Whether to use the 3d annotator for the series.
        self.is_volumetric = False
        volumetric_box = self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        )
        column.addWidget(volumetric_box)

        # Compute device: "auto" plus whatever util reports as available.
        self.device = "auto"
        self.device_dropdown, device_row = self._add_choice_param(
            "device",
            self.device,
            ["auto"] + util._available_devices(),
            tooltip=get_tooltip("embedding", "device"),
        )
        column.addLayout(device_row)

        # Optional directory for saving the image embeddings.
        self.embeddings_save_path = None
        _, embeddings_row = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path"),
        )
        column.addLayout(embeddings_row)

        # Optional custom checkpoint file.
        self.custom_weights = None
        _, weights_row = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights"),
        )
        column.addLayout(weights_row)

        # Tile shape; converted via widgets._process_tiling_inputs when the annotator starts.
        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, tile_row = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling"),
        )
        column.addLayout(tile_row)

        # Halo shape; converted together with the tile shape.
        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, halo_row = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo"),
        )
        column.addLayout(halo_row)

        return widgets._make_collapsible(panel, title="Advanced Settings")

    def _validate_inputs(self):
        """Check that the input folder has files matching the pattern and that an output folder is set.

        Returns:
            The result of `widgets._generate_message("error", ...)` if something is missing,
            otherwise False.
        """
        problems = []
        if self.folder is None or not glob(os.path.join(self.folder, self.pattern)):
            problems.append("The input folder is missing or empty. ")
        if self.output_folder is None:
            problems.append("The output folder is missing.")

        if not problems:
            return False
        return widgets._generate_message("error", "".join(problems))

    def __call__(self, skip_validate=False):
        """Start the annotator for the selected folder with the current settings.

        Args:
            skip_validate: If True, skip the input/output folder validation.
        """
        self._validate_model_type_and_custom_weights()

        if not skip_validate:
            if self._validate_inputs():
                return

        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        # NOTE(review): return_viewer=True presumably prevents the annotator from starting its
        # own napari event loop, since this widget already lives in a running viewer; confirm.
        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            viewer=self._viewer,
            return_viewer=True,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape,
            halo=halo,
            checkpoint_path=self.custom_weights,
            device=self.device,
            is_volumetric=self.is_volumetric,
        )
QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
def __init__(self, viewer: napari.Viewer, parent=None):
    """Build the ImageSeriesAnnotator widget UI.

    Args:
        viewer: The napari viewer in which the annotation tool is started.
        parent: Optional parent Qt widget.
    """
    super().__init__(parent=parent)
    self._viewer = viewer

    # General options: input/output folders and the model family.
    self._create_options()
    root = self.layout()

    # Collapsible panel with the advanced settings.
    root.addWidget(self._create_settings())

    # Button that launches the annotator with the current settings.
    self.run_button = QtWidgets.QPushButton("Annotate Images")
    self.run_button.clicked.connect(self.__call__)
    root.addWidget(self.run_button)
Inherited Members
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- changeEvent
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuEvent
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- effectiveWinId
- ensurePolished
- enterEvent
- event
- find
- focusInEvent
- focusNextChild
- focusNextPrevChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyPressEvent
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumSizeHint
- minimumWidth
- mouseDoubleClickEvent
- mouseGrabber
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- paintEvent
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- resizeEvent
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeHint
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- wheelEvent
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- eventFilter
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM