micro_sam.sam_annotator.image_series_annotator
import os
import time

from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets

from .. import util
from . import _widgets as widgets
from ..precompute_state import _precompute_state_for_files
from ..instance_segmentation import get_decoder
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from ._state import AnnotatorState
from .util import _sync_embedding_widget
from ._tooltips import get_tooltip


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    """Load the SAM model and, if requested, precompute the embeddings for all images.

    Args:
        images: The images, either file paths or numpy arrays.
        model_type: The name of the Segment Anything model to load.
        embedding_path: Folder for the precomputed embeddings. If `None`, nothing is precomputed.
        tile_shape: Tile shape forwarded to the precomputation.
        halo: Halo forwarded to the precomputation.
        precompute_amg_state: Whether to also precompute the state for automatic mask generation.
        checkpoint_path: Optional custom checkpoint for the model.
        device: The device to use, resolved via `util.get_device`.
        ndim: The dimensionality of the data forwarded to the precomputation.
        prefer_decoder: Whether to build the segmentation decoder if the model state contains one.

    Returns:
        The predictor, the decoder (or `None`) and a list with one embedding path
        (or `None`) per image.
    """
    t_start = time.time()

    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    if prefer_decoder and "decoder_state" in state:
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        # NOTE(review): the file names below are presumably the ones written by
        # `_precompute_state_for_files`; the assert afterwards checks that they exist.
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)

    t_run = time.time() - t_start
    # Round the total first, so that the seconds part can never be displayed as 60.
    minutes, seconds = divmod(int(round(t_run)), 60)
    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths


def _get_input_shape(image, is_volumetric=False):
    """Return the shape of the image without its trailing (channel) axis.

    Args:
        image: The image data with 2, 3 or 4 dimensions.
        is_volumetric: Whether a 3-dimensional input is a volume (full shape is kept)
            or a 2d image with a trailing axis (which is dropped).

    Returns:
        The shape tuple.

    Raises:
        ValueError: If the image does not have 2, 3 or 4 dimensions.
    """
    if image.ndim == 2:
        return image.shape
    if image.ndim == 3:
        return image.shape if is_volumetric else image.shape[:-1]
    if image.ndim == 4:
        return image.shape[:-1]
    raise ValueError(f"Invalid image dimensions, expect 2, 3 or 4 dimensions, got {image.ndim}")


def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device,
    embedding_path,
):
    """Create the viewer (if needed), initialize the annotator state and add the annotator widget.

    Args:
        viewer: An existing napari viewer or `None` to create a new one.
        image: The first image to annotate.
        image_embedding_path: The embedding path for this image (may be `None`).
        model_type: The name of the Segment Anything model.
        halo: The halo for tiled prediction.
        tile_shape: The tile shape for tiled prediction.
        predictor: The predictor passed on to the annotator state.
        decoder: The decoder passed on to the annotator state (may be `None`).
        is_volumetric: Whether to use the 3d annotator.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
        checkpoint_path: Optional custom checkpoint for the model.
        device: The device to use.
        embedding_path: The root embedding path that is shown in the embedding widget.

    Returns:
        The viewer and the annotator widget.

    Raises:
        ValueError: If the image dimensionality does not match the selected annotator.
    """
    # Validate the input before doing any expensive initialization.
    if is_volumetric:
        if image.ndim not in [3, 4]:
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
    else:
        if image.ndim not in (2, 3):
            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")

    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    annotator = Annotator3d(viewer) if is_volumetric else Annotator2d(viewer)
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo
    )
    return viewer, annotator


def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to use for the model. It is resolved via `util.get_device`.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        `None` is returned if all images are already segmented and `skip_segmented=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    # Fail early with a clear message instead of an IndexError after the model was loaded.
    if len(images) == 0:
        raise ValueError("The list of images to annotate is empty.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Arrays have no file name, so the result is named after the index in the series.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _load_image(idx):
        # Arrays are used as they are, file paths are read from disk.
        return images[idx] if have_inputs_as_arrays else imageio.imread(images[idx])

    def _first_unsegmented(idx):
        # Advance to the first index (>= idx) without a saved result. Returns len(images) if none is left.
        while idx < len(images) and os.path.exists(_get_save_path(images[idx], idx)):
            idx += 1
        return idx

    # Check which image to load first. Without skipping this is simply the first image.
    next_image_id = 0
    if skip_segmented:
        next_image_id = _first_unsegmented(next_image_id)
        if next_image_id == len(images):
            print(end_msg)
            return
        print("The first image to annotate is image number", next_image_id)
    image = _load_image(next_image_id)
    image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device,
        embedding_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _ask_to_close():
        # Inform the user via dialog and close the viewer unless the dialog was aborted.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # The series is already finished (the user kept napari open), there is nothing left to save.
        if next_image_id >= len(images):
            _ask_to_close()
            return

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            if widgets._generate_message("info", msg):
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image, if skipping images that are already segmented check if we have to load it.
        next_image_id += 1
        if skip_segmented:
            next_image_id = _first_unsegmented(next_image_id)

        if next_image_id == len(images):
            _ask_to_close()
            return

        # Load the next image.
        print(
            "Loading next image:",
            images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image = _load_image(next_image_id)
        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()
        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()


def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The 2d annotator is used by default; pass `is_volumetric=True` via `kwargs` for the 3d annotator.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    if not image_files:
        raise ValueError(f"No files matching the pattern '{pattern}' were found in '{input_folder}'.")
    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )


class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget for selecting a folder of images and annotating them one after the other."""

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # The general options come first, followed by the collapsible advanced settings.
        self._create_options()
        self.layout().addWidget(self._create_settings())

        # The run button starts the annotation of the selected folder.
        run_button = QtWidgets.QPushButton("Annotate Images")
        run_button.clicked.connect(self.__call__)
        self.run_button = run_button
        self.layout().addWidget(run_button)

    def _create_options(self):
        """Add the input folder, output folder and model selection to the widget."""
        self.folder = None
        _, folder_layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(folder_layout)

        self.output_folder = None
        _, output_layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(output_layout)

        self.model_type = util._DEFAULT_MODEL
        # Entries ending with "decoder" are not offered as a choice.
        model_options = [name for name in util.models().urls.keys() if not name.endswith("decoder")]
        _, model_layout = self._add_choice_param(
            "model_type", self.model_type, model_options, title="Model:",
            tooltip=get_tooltip("embedding", "model")
        )
        self.layout().addLayout(model_layout)

    def _create_settings(self):
        """Build the collapsible widget with the advanced settings and return it."""
        panel = QtWidgets.QWidget()
        box = QtWidgets.QVBoxLayout()
        panel.setLayout(box)

        self.pattern = "*"
        _, pattern_layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        box.addLayout(pattern_layout)

        self.is_volumetric = False
        volumetric_box = self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        )
        box.addWidget(volumetric_box)

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, device_layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        box.addLayout(device_layout)

        self.embeddings_save_path = None
        _, embeddings_layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        box.addLayout(embeddings_layout)

        self.custom_weights = None
        _, weights_layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        box.addLayout(weights_layout)

        self.tile_x = self.tile_y = 0
        self.tile_x_param, self.tile_y_param, tile_layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        box.addLayout(tile_layout)

        self.halo_x = self.halo_y = 0
        self.halo_x_param, self.halo_y_param, halo_layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        box.addLayout(halo_layout)

        return widgets._make_collapsible(panel, title="Advanced Settings")

    def _validate_inputs(self):
        """Check the folder selection.

        Returns:
            `False` if the inputs are fine, otherwise the result of the error dialog.
        """
        problems = []
        if self.folder is None or not glob(os.path.join(self.folder, self.pattern)):
            problems.append("The input folder is missing or empty. ")
        if self.output_folder is None:
            problems.append("The output folder is missing.")

        if not problems:
            return False
        return widgets._generate_message("error", "".join(problems))

    def __call__(self, skip_validate=False):
        """Start the annotation for the selected folder in the viewer of this widget.

        Args:
            skip_validate: Whether to skip the validation of the folder selection.
        """
        if not skip_validate:
            if self._validate_inputs():
                return

        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        # The viewer is returned instead of calling `napari.run`, since the widget lives in a running viewer.
        image_folder_annotator(
            self.folder, self.output_folder, self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )


def main():
    """@private

    Command line entry point: parse the arguments and run `image_folder_annotator`.
    """
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation."
    )
    # The two flags below are 'store_false': the feature is on by default and passing the flag turns it off.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to DISABLE decoder based instance segmentation (it is enabled by default)."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to NOT skip images that were already segmented (they are skipped by default)."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def
image_series_annotator( images: Union[List[Union[str, os.PathLike]], List[numpy.ndarray]], output_folder: str, model_type: str = 'vit_l', embedding_path: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, is_volumetric: bool = False, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True, skip_segmented: bool = True) -> Optional[napari.viewer.Viewer]:
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to use for the model. It is passed on to the precomputation
            and to the annotator state.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        `None` is returned if all images are already segmented and `skip_segmented=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    # Fail early with a clear message instead of an IndexError after the model was loaded.
    if len(images) == 0:
        raise ValueError("The list of images to annotate is empty.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Arrays have no file name, so the result is named after the index in the series.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _load_image(idx):
        # Arrays are used as they are, file paths are read from disk.
        return images[idx] if have_inputs_as_arrays else imageio.imread(images[idx])

    def _first_unsegmented(idx):
        # Advance to the first index (>= idx) without a saved result. Returns len(images) if none is left.
        while idx < len(images) and os.path.exists(_get_save_path(images[idx], idx)):
            idx += 1
        return idx

    # Check which image to load first. Without skipping this is simply the first image.
    next_image_id = 0
    if skip_segmented:
        next_image_id = _first_unsegmented(next_image_id)
        if next_image_id == len(images):
            print(end_msg)
            return
        print("The first image to annotate is image number", next_image_id)
    image = _load_image(next_image_id)
    image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device,
        embedding_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _ask_to_close():
        # Inform the user via dialog and close the viewer unless the dialog was aborted.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # The series is already finished (the user kept napari open), there is nothing left to save.
        if next_image_id >= len(images):
            _ask_to_close()
            return

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            if widgets._generate_message("info", msg):
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image, if skipping images that are already segmented check if we have to load it.
        next_image_id += 1
        if skip_segmented:
            next_image_id = _first_unsegmented(next_image_id)

        if next_image_id == len(images):
            _ask_to_close()
            return

        # Load the next image.
        print(
            "Loading next image:",
            images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image = _load_image(next_image_id)
        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()
        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
Run the annotation tool for a series of images (supported for both 2d and 3d images).
Arguments:
- images: List of the file paths or list of (set of) slices for the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- embedding_path: Filepath where to save the embeddings.
- tile_shape: Shape of tiles for tiled embedding prediction.
If
None
then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- is_volumetric: Whether to use the 3d annotator.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
- skip_segmented: Whether to skip images that were already segmented.
Returns:
The napari viewer, only returned if
return_viewer=True
.
def
image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The 2d annotator is used by default; pass `is_volumetric=True` via `kwargs` for the 3d annotator.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    # An empty series cannot be annotated, so report it here with a clear message.
    if not image_files:
        raise ValueError(f"No files matching the pattern '{pattern}' were found in '{input_folder}'.")
    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )
Run the annotation tool for a series of images in a folder (2d by default; 3d when `is_volumetric=True` is passed via kwargs).
Arguments:
- input_folder: The folder with the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- pattern: The glob pattern for loading files from
input_folder
. By default all files will be loaded.
- viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- kwargs: The keyword arguments for
micro_sam.sam_annotator.image_series_annotator
.
Returns:
The napari viewer, only returned if
return_viewer=True
.
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget that collects the settings for annotating a folder of images.

    The widget shows an input folder, an output folder and a model choice,
    followed by a collapsible "Advanced Settings" section and an
    "Annotate Images" button. Pressing the button validates the inputs and
    forwards the collected values to `image_folder_annotator`.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        """Build the widget.

        Args:
            viewer: The napari viewer that is handed on to `image_folder_annotator`.
            parent: Optional parent passed to the base class constructor.
        """
        super().__init__(parent=parent)
        self._viewer = viewer

        # General options first, then the collapsible advanced settings.
        self._create_options()
        advanced_settings = self._create_settings()
        self.layout().addWidget(advanced_settings)

        # The run button; clicking it invokes this widget via `__call__`.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the input folder, output folder and model selection to the layout.

        Each value is initialised on the instance before the matching input
        element is created, so the element starts from that value.
        """
        self.folder = None
        _, folder_layout = self._add_path_param(
            "folder",
            self.folder,
            "directory",
            title="Input Folder",
            placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder"),
        )
        self.layout().addLayout(folder_layout)

        self.output_folder = None
        _, output_layout = self._add_path_param(
            "output_folder",
            self.output_folder,
            "directory",
            title="Output Folder",
            placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder"),
        )
        self.layout().addLayout(output_layout)

        self.model_type = util._DEFAULT_MODEL
        # Offer every registered model except the entries whose name ends in "decoder".
        selectable_models = [
            name for name in list(util.models().urls.keys()) if not name.endswith("decoder")
        ]
        _, model_layout = self._add_choice_param(
            "model_type",
            self.model_type,
            selectable_models,
            title="Model:",
            tooltip=get_tooltip("embedding", "model"),
        )
        self.layout().addLayout(model_layout)

    def _create_settings(self):
        """Create the collapsible "Advanced Settings" section.

        Returns:
            The object returned by `widgets._make_collapsible` wrapping all
            advanced setting elements.
        """
        container = QtWidgets.QWidget()
        container_layout = QtWidgets.QVBoxLayout()
        container.setLayout(container_layout)

        # File name pattern used to select the images inside the input folder.
        self.pattern = "*"
        _, pattern_layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        container_layout.addLayout(pattern_layout)

        self.is_volumetric = False
        volumetric_checkbox = self._add_boolean_param(
            "is_volumetric",
            self.is_volumetric,
            tooltip=get_tooltip("image_series_annotator", "is_volumetric"),
        )
        container_layout.addWidget(volumetric_checkbox)

        self.device = "auto"
        available = ["auto"] + util._available_devices()
        self.device_dropdown, device_layout = self._add_choice_param(
            "device", self.device, available, tooltip=get_tooltip("embedding", "device")
        )
        container_layout.addLayout(device_layout)

        self.embeddings_save_path = None
        _, embeddings_layout = self._add_path_param(
            "embeddings_save_path",
            self.embeddings_save_path,
            "directory",
            title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path"),
        )
        container_layout.addLayout(embeddings_layout)

        self.custom_weights = None
        _, weights_layout = self._add_path_param(
            "custom_weights",
            self.custom_weights,
            "file",
            title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights"),
        )
        container_layout.addLayout(weights_layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, tile_layout = self._add_shape_param(
            ("tile_x", "tile_y"),
            (self.tile_x, self.tile_y),
            min_val=0,
            max_val=2048,
            step=16,
            tooltip=get_tooltip("embedding", "tiling"),
        )
        container_layout.addLayout(tile_layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, halo_layout = self._add_shape_param(
            ("halo_x", "halo_y"),
            (self.halo_x, self.halo_y),
            min_val=0,
            max_val=512,
            tooltip=get_tooltip("embedding", "halo"),
        )
        container_layout.addLayout(halo_layout)

        return widgets._make_collapsible(container, title="Advanced Settings")

    def _validate_inputs(self):
        """Check that an input folder with matching files and an output folder are set.

        Returns:
            False if both inputs are usable. Otherwise the result of
            `widgets._generate_message("error", ...)` with a message naming
            what is missing.
        """
        if self.folder is None:
            missing_data = True
        else:
            # The folder counts as missing when nothing in it matches the pattern.
            missing_data = not glob(os.path.join(self.folder, self.pattern))
        missing_output = self.output_folder is None

        if not (missing_data or missing_output):
            return False

        problems = []
        if missing_data:
            problems.append("The input folder is missing or empty. ")
        if missing_output:
            problems.append("The output folder is missing.")
        return widgets._generate_message("error", "".join(problems))

    def __call__(self, skip_validate=False):
        """Run the image folder annotator with the current widget values.

        Args:
            skip_validate: If true, the input validation is not performed.
        """
        if not skip_validate:
            if self._validate_inputs():
                return

        tile_shape, halo = widgets._process_tiling_inputs(
            self.tile_x, self.tile_y, self.halo_x, self.halo_y
        )

        image_folder_annotator(
            self.folder,
            self.output_folder,
            self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape,
            halo=halo,
            checkpoint_path=self.custom_weights,
            device=self.device,
            is_volumetric=self.is_volumetric,
            viewer=self._viewer,
            return_viewer=True,
        )
QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
def __init__(self, viewer: napari.Viewer, parent=None):
    """Build the widget.

    Args:
        viewer: The napari viewer stored on the instance as `_viewer`.
        parent: Optional parent passed to the base class constructor.
    """
    super().__init__(parent=parent)
    self._viewer = viewer

    # General options first, then the collapsible settings section.
    self._create_options()
    collapsible_settings = self._create_settings()
    self.layout().addWidget(collapsible_settings)

    # The run button; clicking it invokes this widget via `__call__`.
    self.run_button = QtWidgets.QPushButton("Annotate Images")
    self.run_button.clicked.connect(self.__call__)
    self.layout().addWidget(self.run_button)
Inherited Members
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- changeEvent
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuEvent
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- effectiveWinId
- ensurePolished
- enterEvent
- event
- find
- focusInEvent
- focusNextChild
- focusNextPrevChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyPressEvent
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumSizeHint
- minimumWidth
- mouseDoubleClickEvent
- mouseGrabber
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- paintEvent
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- resizeEvent
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeHint
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- wheelEvent
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- eventFilter
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM