micro_sam.sam_annotator.image_series_annotator

import os
import time
from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets
from qtpy.QtCore import QTimer

from .. import util
from . import _widgets as widgets
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from .util import _sync_embedding_widget
from ..instance_segmentation import get_decoder
from ..precompute_state import _precompute_state_for_files


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    t_start = time.time()

    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    if prefer_decoder and "decoder_state" in state:
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)

        t_run = time.time() - t_start
        minutes = int(t_run // 60)
        seconds = int(round(t_run % 60))
        print(f"Precomputation took {t_run:.1f} seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths
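
# --- Example (illustration only, not part of the module) ---
# A minimal sketch of the embedding naming scheme implemented in _precompute,
# so cached .zarr files can be matched to their inputs. The file names below
# are hypothetical.
from pathlib import Path
import os

embedding_path = "embeddings"
image_files = ["data/cell_01.tif", "data/cell_02.tif"]

# Inputs passed as file paths are cached as "<stem>.zarr".
print([os.path.join(embedding_path, f"{Path(p).stem}.zarr") for p in image_files])
# -> ['embeddings/cell_01.zarr', 'embeddings/cell_02.zarr']

# Inputs passed as in-memory arrays are enumerated instead.
print([os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(2)])
# -> ['embeddings/embedding_00000.zarr', 'embeddings/embedding_00001.zarr']
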
def _get_input_shape(image, is_volumetric=False):
    if image.ndim == 2:
        image_shape = image.shape
    elif image.ndim == 3:
        if is_volumetric:
            image_shape = image.shape
        else:
            image_shape = image.shape[:-1]
    elif image.ndim == 4:
        image_shape = image.shape[:-1]
    else:
        raise ValueError(f"Invalid image dimensions, expect 2 to 4 dimensions, got {image.ndim}.")

    return image_shape
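
# --- Example (illustration only, not part of the module) ---
# A quick sanity check of the shape rules above, using small numpy arrays.
import numpy as np

assert _get_input_shape(np.zeros((512, 512))) == (512, 512)                                  # 2d image
assert _get_input_shape(np.zeros((512, 512, 3))) == (512, 512)                               # 2d RGB: channel axis dropped
assert _get_input_shape(np.zeros((32, 512, 512)), is_volumetric=True) == (32, 512, 512)      # 3d volume
assert _get_input_shape(np.zeros((32, 512, 512, 3)), is_volumetric=True) == (32, 512, 512)   # 3d with channels
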
def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device, embedding_path,
    segmentation_path, initial_seg,
):
    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False, use_cli=True,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    if is_volumetric:
        if image.ndim not in (3, 4):
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
        annotator = Annotator3d(viewer, reset_state=False)
    else:
        if image.ndim not in (2, 3):
            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
        annotator = Annotator2d(viewer, reset_state=False)

    if os.path.exists(segmentation_path):
        segmentation_result = imageio.imread(segmentation_path)
    else:
        segmentation_result = None
    if initial_seg is not None and segmentation_result is None:
        segmentation_result = initial_seg if isinstance(initial_seg, np.ndarray) else imageio.imread(initial_seg)
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )
    return viewer, annotator
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    initial_segmentations: Optional[Union[List[Union[os.PathLike, str]], List[np.ndarray]]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where the embeddings will be saved.
        initial_segmentations: Initial segmentations to be corrected.
            By default no initial segmentations are loaded.
            If given, the initial segmentations will be loaded into 'committed_objects'.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
        prefer_decoder: Whether to use decoder-based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.
        skip_segmented: Whether to skip images that were already segmented.
            If set to 'False', then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    if initial_segmentations is not None and len(initial_segmentations) != len(images):
        raise ValueError(
            "You have passed initial segmentations, but the number of images and segmentations is not the same: "
            f"{len(images)} != {len(initial_segmentations)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    next_image_id = 0
    have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray)

    def _get_save_path(image_path, current_idx):
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _load_image(image_id):
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        image_embedding_path = embedding_paths[image_id]
        return image, image_embedding_path

    # Check which image to load next if we skip segmented images.
    if skip_segmented:
        while True:
            if next_image_id == len(images):
                print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
                return

            save_path = _get_save_path(images[next_image_id], next_image_id)
            if not os.path.exists(save_path):
                print("The first image to annotate is image number", next_image_id)
                image, image_embedding_path = _load_image(next_image_id)
                break

            next_image_id += 1

    else:
        save_path = _get_save_path(images[next_image_id], next_image_id)
        image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path, None if initial_segmentations is None else initial_segmentations[next_image_id],
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        abort = False
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Check if we are done.
        if (next_image_id + 1) == len(images):
            # Inform the user via dialog.
            abort = widgets._generate_message("info", end_msg)
            if not abort:
                QTimer.singleShot(0, viewer.close)
            return

        # Clear the segmentation already here, to avoid a lagging removal later.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image.
        next_image_id += 1

        # If we are skipping images that are already segmented, check which image to load next.
        save_path = _get_save_path(images[next_image_id], next_image_id)
        if skip_segmented:
            segmentation_result = None
            while os.path.exists(save_path):
                next_image_id += 1

                # Check if we are done.
                if next_image_id == len(images):
                    # Inform the user via dialog.
                    abort = widgets._generate_message("info", end_msg)
                    if not abort:
                        viewer.close()
                    return

                save_path = _get_save_path(images[next_image_id], next_image_id)
        else:
            if os.path.exists(save_path):
                segmentation_result = imageio.imread(save_path)
            else:
                segmentation_result = None

        # Load the initial segmentation if it exists and we don't have a segmentation result loaded yet.
        if initial_segmentations is not None and segmentation_result is None:
            initial_seg = initial_segmentations[next_image_id]
            segmentation_result = initial_seg if isinstance(initial_seg, np.ndarray) else imageio.imread(initial_seg)

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        if have_inputs_as_arrays:
            image = images[next_image_id]
        else:
            image = imageio.imread(images[next_image_id])

        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
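
# --- Example (illustration only, not part of the module) ---
# A minimal usage sketch with hypothetical file paths, assuming the function is
# re-exported from micro_sam.sam_annotator. Segmentations are written to
# "segmentations"; embeddings are cached in "embeddings" so they are computed
# only once. Re-running the same call skips already segmented images, since
# skip_segmented defaults to True.
from micro_sam.sam_annotator import image_series_annotator

images = ["data/image_01.tif", "data/image_02.tif", "data/image_03.tif"]
image_series_annotator(images, output_folder="segmentations", embedding_path="embeddings")
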
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    initial_segmentation_folder: Optional[str] = None,
    initial_segmentation_pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        initial_segmentation_folder: A folder with initial segmentation results.
            By default no initial segmentations are loaded.
        initial_segmentation_pattern: The glob pattern for loading files from `initial_segmentation_folder`.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    if initial_segmentation_folder is None:
        initial_segmentations = None
    else:
        initial_segmentations = sorted(glob(os.path.join(
            initial_segmentation_folder, initial_segmentation_pattern
        )))

    return image_series_annotator(
        image_files, output_folder,
        initial_segmentations=initial_segmentations,
        viewer=viewer, return_viewer=return_viewer, **kwargs
    )
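
# --- Example (illustration only, not part of the module) ---
# A hedged sketch with hypothetical folder names. Extra keyword arguments, such
# as embedding_path or is_volumetric, are forwarded to image_series_annotator.
from micro_sam.sam_annotator import image_folder_annotator

image_folder_annotator(
    input_folder="data",
    output_folder="segmentations",
    pattern="*.tif",
    initial_segmentation_folder="proposals",  # optional: segmentations to correct
    embedding_path="embeddings",
)
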
class ImageSeriesAnnotator(widgets._WidgetBase):
    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button to start the annotation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        # Add the model family widget section.
        layout = self._create_model_section(create_layout=False)
        self.layout().addLayout(layout)

    def _create_settings(self):
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        self._validate_model_type_and_custom_weights()

        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )
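
# --- Example (illustration only, not part of the module) ---
# A small sketch of how the widget could be docked into an existing napari
# viewer; the dock-widget name is arbitrary.
import napari

viewer = napari.Viewer()
widget = ImageSeriesAnnotator(viewer)
viewer.window.add_dock_widget(widget, name="Image Series Annotator")
napari.run()
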
def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder, e.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "--initial_segmentation_folder",
        help="A folder with initial segmentation results. By default no initial segmentations are loaded."
    )
    parser.add_argument(
        "--initial_segmentation_pattern", default="*",
        help="The glob pattern for loading files from `initial_segmentation_folder`."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only on Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, default=None, help="The tile shape for using tiled prediction."
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, default=None, help="The halo for using tiled prediction."
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation."
    )
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to disable decoder-based instance segmentation, even if the model provides a decoder."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to also load images that have already been segmented, instead of skipping them."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        initial_segmentation_folder=args.initial_segmentation_folder,
        initial_segmentation_pattern=args.initial_segmentation_pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
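
# --- Example (illustration only, not part of the module) ---
# main() backs the command line interface. Assuming the console script is
# installed as "micro_sam.image_series_annotator" (the script name is an
# assumption, check your installation), a typical invocation could be:
#
#     micro_sam.image_series_annotator -i data -o segmentations -e embeddings
#
# Note that --prefer_decoder and --skip_segmented use action="store_false":
# passing either flag disables the corresponding behavior, which is otherwise
# enabled by default.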