micro_sam.sam_annotator.image_series_annotator
import os
import time
from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets

from .. import util
from . import _widgets as widgets
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from .util import _sync_embedding_widget
from ..instance_segmentation import get_decoder
from ..precompute_state import _precompute_state_for_files


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    """Load the SAM predictor (and optional decoder) and precompute the embeddings for all images.

    Args:
        images: List of image file paths or list of numpy arrays.
        model_type: The Segment Anything model to use.
        embedding_path: Folder where the embeddings are saved. If `None`, nothing is precomputed.
        tile_shape: Shape of tiles for tiled embedding prediction.
        halo: Shape of the overlap between tiles.
        precompute_amg_state: Whether to also precompute the state for automatic mask generation.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The device for the model (passed through `util.get_device`).
        ndim: The dimensionality of the data, 2 or 3.
        prefer_decoder: Whether to load the instance segmentation decoder if the checkpoint contains one.

    Returns:
        The predictor, the decoder (or `None`) and one embedding path per image
        (all `None` if `embedding_path` is `None`).

    Raises:
        RuntimeError: If an expected embedding file was not created by the precomputation.
    """
    t_start = time.time()

    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    if prefer_decoder and "decoder_state" in state:
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        # Array inputs have no file name, so their embeddings are named by index.
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        missing = [emb_path for emb_path in embedding_paths if not os.path.exists(emb_path)]
        if missing:
            raise RuntimeError(
                f"{len(missing)} expected embedding(s) were not precomputed, e.g. {missing[0]}"
            )

    t_run = time.time() - t_start
    minutes = int(t_run // 60)
    seconds = int(round(t_run % 60, 0))
    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths


def _get_input_shape(image, is_volumetric=False):
    """Return the spatial shape of an image.

    For 2d annotation a 3d array is treated as having a trailing channel axis, which is dropped.
    For 4d arrays the last axis is always dropped.

    Args:
        image: The image array.
        is_volumetric: Whether the image is a 3d volume.

    Returns:
        The spatial shape of the image.

    Raises:
        ValueError: If the image does not have 2, 3 or 4 dimensions.
    """
    if image.ndim == 2:
        return image.shape
    if image.ndim == 3:
        return image.shape if is_volumetric else image.shape[:-1]
    if image.ndim == 4:
        return image.shape[:-1]
    raise ValueError(f"Invalid image dimensions, expect 2, 3 or 4 dimensions, got {image.ndim}")


def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device, embedding_path,
):
    """Create (or reuse) the viewer, initialize the annotator state and add the annotator widget.

    Args:
        viewer: An existing napari viewer, or `None` to create a new one.
        image: The first image to annotate.
        image_embedding_path: The embedding path for this image (may be `None`).
        model_type: The Segment Anything model to use.
        halo: Shape of the overlap between tiles.
        tile_shape: Shape of tiles for tiled embedding prediction.
        predictor: The already loaded SAM predictor.
        decoder: The instance segmentation decoder, or `None`.
        is_volumetric: Whether to use the 3d annotator.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
        checkpoint_path: Path to a custom checkpoint.
        device: The device for the model.
        embedding_path: The folder for all embeddings, used to sync the embedding widget.

    Returns:
        The viewer and the annotator widget.

    Raises:
        ValueError: If the image dimensionality does not fit the selected annotator.
    """
    # Validate the input before the (potentially expensive) predictor initialization.
    if is_volumetric:
        if image.ndim not in (3, 4):
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
    elif image.ndim not in (2, 3):
        raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")

    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    annotator = Annotator3d(viewer) if is_volumetric else Annotator2d(viewer)
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo
    )
    return viewer, annotator


def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to use for the model. By default the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    # Fail early, before loading the model, if there is nothing to annotate.
    if len(images) == 0:
        raise ValueError("No images were given for annotation.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Return the output path of the segmentation for the image at `current_idx`.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _find_next_image_id(start_id):
        # Return the first index >= start_id to annotate; len(images) if none is left.
        image_id = start_id
        if skip_segmented:
            while image_id < len(images) and os.path.exists(_get_save_path(images[image_id], image_id)):
                image_id += 1
        return image_id

    def _load_image(image_id):
        # Return the image at `image_id`, reading it from disk for file path inputs.
        image = images[image_id]
        return image if have_inputs_as_arrays else imageio.imread(image)

    # Determine the first image to annotate (skipping already segmented ones if requested).
    next_image_id = _find_next_image_id(0)
    if next_image_id == len(images):
        print(end_msg)
        return
    if skip_segmented:
        print("The first image to annotate is image number", next_image_id)
    image = _load_image(next_image_id)
    image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _ask_to_close():
        # napari is closed unless the dialog result signals an abort.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images are done (the user previously declined to close napari): nothing to save.
        if next_image_id >= len(images):
            _ask_to_close()
            return

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(segmentation)

        # Go to the next image, skipping already segmented images if requested.
        next_image_id = _find_next_image_id(next_image_id + 1)
        if next_image_id == len(images):
            # Inform the user via dialog.
            _ask_to_close()
            return

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image = _load_image(next_image_id)
        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            checkpoint_path=checkpoint_path,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()


def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    Both 2d and 3d images are supported; pass `is_volumetric=True` via `kwargs` for 3d data.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    if not image_files:
        raise ValueError(f"No images matching the pattern '{pattern}' were found in '{input_folder}'.")

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )


class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget that starts the image series annotator for a folder of images.

    The user selects an input folder, an output folder, a model and optional advanced settings;
    the "Annotate Images" button then runs `image_folder_annotator` in this widget's viewer.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button that starts the annotation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the input folder, output folder and model selection to the widget."""
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        self.model_type = util._DEFAULT_MODEL
        # Models ending in "decoder" are not offered as a choice.
        model_options = [model for model in util.models().urls.keys() if not model.endswith("decoder")]
        _, layout = self._add_choice_param(
            "model_type", self.model_type, model_options, title="Model:",
            tooltip=get_tooltip("embedding", "model")
        )
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Create the collapsible "Advanced Settings" section and return it."""
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Check that the input folder has files matching the pattern and that an output folder is set.

        Returns:
            `False` if the inputs are valid, otherwise the result of the error message dialog
            (`__call__` aborts when this is truthy).
        """
        # Treat an empty path (e.g. from a cleared path field) the same as an unset one.
        missing_data = not self.folder or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = not self.output_folder
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Start the annotation for the selected folder, unless the input validation fails."""
        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            self.folder, self.output_folder, self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )


def main():
    """@private"""
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation."
    )
    # NOTE: the two flags below are 'store_false': passing them DISABLES the behavior.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to NOT use decoder-based instance segmentation, even if the model has a decoder."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to NOT skip images that already have a segmentation in the output folder."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def
image_series_annotator( images: Union[List[Union[str, os.PathLike]], List[numpy.ndarray]], output_folder: str, model_type: str = 'vit_l', embedding_path: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, is_volumetric: bool = False, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True, skip_segmented: bool = True) -> Optional[napari.viewer.Viewer]:
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to use for the model. By default the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    # Fail early, before loading the model, if there is nothing to annotate.
    if len(images) == 0:
        raise ValueError("No images were given for annotation.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Return the output path of the segmentation for the image at `current_idx`.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _find_next_image_id(start_id):
        # Return the first index >= start_id to annotate; len(images) if none is left.
        image_id = start_id
        if skip_segmented:
            while image_id < len(images) and os.path.exists(_get_save_path(images[image_id], image_id)):
                image_id += 1
        return image_id

    def _load_image(image_id):
        # Return the image at `image_id`, reading it from disk for file path inputs.
        image = images[image_id]
        return image if have_inputs_as_arrays else imageio.imread(image)

    # Determine the first image to annotate (skipping already segmented ones if requested).
    next_image_id = _find_next_image_id(0)
    if next_image_id == len(images):
        print(end_msg)
        return
    if skip_segmented:
        print("The first image to annotate is image number", next_image_id)
    image = _load_image(next_image_id)
    image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _ask_to_close():
        # napari is closed unless the dialog result signals an abort.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images are done (the user previously declined to close napari): nothing to save.
        if next_image_id >= len(images):
            _ask_to_close()
            return

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(segmentation)

        # Go to the next image, skipping already segmented images if requested.
        next_image_id = _find_next_image_id(next_image_id + 1)
        if next_image_id == len(images):
            # Inform the user via dialog.
            _ask_to_close()
            return

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )
        image = _load_image(next_image_id)
        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            checkpoint_path=checkpoint_path,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
Run the annotation tool for a series of images (supported for both 2d and 3d images).
Arguments:
- images: List of the file paths or list of (set of) slices for the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- embedding_path: Filepath where to save the embeddings.
- tile_shape: Shape of tiles for tiled embedding prediction.
If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- is_volumetric: Whether to use the 3d annotator.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
- skip_segmented: Whether to skip images that were already segmented.
Returns:
The napari viewer, only returned if `return_viewer=True`.
def
image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    Both 2d and 3d images are supported; pass `is_volumetric=True` via `kwargs` for 3d data.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    # Give a clear error instead of failing later inside `image_series_annotator`.
    if not image_files:
        raise ValueError(f"No images matching the pattern '{pattern}' were found in '{input_folder}'.")

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )
Run the annotation tool (2d, or 3d via `is_volumetric=True`) for a series of images in a folder.
Arguments:
- input_folder: The folder with the images to be annotated.
- output_folder: The folder where the segmentation results are saved.
- pattern: The glob pattern for loading files from `input_folder`. By default all files will be loaded.
- viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.
Returns:
The napari viewer, only returned if `return_viewer=True`.
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget for starting the annotation of a series of images stored in a folder.

    The general options (input folder, output folder, model) are always shown.
    The advanced settings (file pattern, volumetric flag, device, embedding path,
    custom weights, tiling) are placed in a collapsible section. Pressing the
    "Annotate Images" button calls `image_folder_annotator` with the current
    values in the viewer this widget was created for.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        """
        Args:
            viewer: The napari viewer. It is stored and passed on to the annotator.
            parent: Optional parent Qt widget.
        """
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button that starts the annotation of the image series.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        # NOTE(review): Qt bindings may forward the `checked` flag of `clicked` as the first
        # positional argument, i.e. into `skip_validate`; for this non-checkable button it is False.
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the general options: input folder, output folder and model type."""
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        self.model_type = util._DEFAULT_MODEL
        # Model names ending in "decoder" are not offered as choices.
        model_options = [model for model in util.models().urls.keys() if not model.endswith("decoder")]
        _, layout = self._add_choice_param(
            "model_type", self.model_type, model_options, title="Model:",
            tooltip=get_tooltip("embedding", "model")
        )
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Create the advanced settings and return them wrapped in a collapsible widget."""
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Glob pattern used to select the images inside the input folder.
        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        # Optional path to a checkpoint file; passed on as `checkpoint_path`.
        self.custom_weights = None
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Check that an input folder with matching files and an output folder are set.

        Returns:
            False if the inputs are valid. Otherwise, the result of
            `widgets._generate_message("error", ...)`, which `__call__` treats as
            "abort" when truthy (NOTE(review): assumed truthy for errors; confirm in _widgets).
        """
        # Empty strings count as unset. Otherwise a cleared path field would make the glob
        # below search the current working directory and pass validation by accident.
        missing_data = not self.folder or not glob(os.path.join(self.folder, self.pattern))
        missing_output = not self.output_folder
        if not (missing_data or missing_output):
            return False

        msg = ""
        if missing_data:
            msg += "The input folder is missing or empty. "
        if missing_output:
            msg += "The output folder is missing."
        return widgets._generate_message("error", msg)

    def __call__(self, skip_validate=False):
        """Start the annotation of the image series with the current widget values.

        Args:
            skip_validate: If True, do not validate the input and output folders first.
        """
        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            self.folder, self.output_folder, self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )
QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
def __init__(self, viewer: napari.Viewer, parent=None):
    """Build the image series annotator widget.

    Args:
        viewer: The napari viewer; it is kept on the widget as `_viewer`.
        parent: Optional parent Qt widget.
    """
    super().__init__(parent=parent)
    self._viewer = viewer

    # First the general options (input folder, output folder, model) ...
    self._create_options()

    # ... then the collapsible block with the advanced settings ...
    settings_widget = self._create_settings()
    self.layout().addWidget(settings_widget)

    # ... and finally the button that starts annotating the image series.
    button = QtWidgets.QPushButton("Annotate Images")
    button.clicked.connect(self.__call__)
    self.run_button = button
    self.layout().addWidget(button)
Inherited Members
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- changeEvent
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuEvent
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- effectiveWinId
- ensurePolished
- enterEvent
- event
- find
- focusInEvent
- focusNextChild
- focusNextPrevChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyPressEvent
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumSizeHint
- minimumWidth
- mouseDoubleClickEvent
- mouseGrabber
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- paintEvent
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- resizeEvent
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeHint
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- wheelEvent
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- eventFilter
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM