micro_sam.sam_annotator.image_series_annotator

  1import os
  2import time
  3from glob import glob
  4from pathlib import Path
  5from typing import List, Optional, Union, Tuple
  6
  7import numpy as np
  8import imageio.v3 as imageio
  9
 10import torch
 11
 12import napari
 13from magicgui import magicgui
 14from qtpy import QtWidgets
 15
 16from .. import util
 17from . import _widgets as widgets
 18from ._tooltips import get_tooltip
 19from ._state import AnnotatorState
 20from .annotator_2d import Annotator2d
 21from .annotator_3d import Annotator3d
 22from .util import _sync_embedding_widget
 23from ..instance_segmentation import get_decoder
 24from ..precompute_state import _precompute_state_for_files
 25
 26
 27def _precompute(
 28    images, model_type, embedding_path,
 29    tile_shape, halo, precompute_amg_state,
 30    checkpoint_path, device, ndim, prefer_decoder,
 31):
 32    t_start = time.time()
 33
 34    device = util.get_device(device)
 35    predictor, state = util.get_sam_model(
 36        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
 37    )
 38    if prefer_decoder and "decoder_state" in state:
 39        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 40    else:
 41        decoder = None
 42
 43    if embedding_path is None:
 44        embedding_paths = [None] * len(images)
 45    else:
 46        _precompute_state_for_files(
 47            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
 48            precompute_amg_state=precompute_amg_state, decoder=decoder,
 49        )
 50        if isinstance(images[0], np.ndarray):
 51            embedding_paths = [
 52                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i, path in enumerate(images)
 53            ]
 54        else:
 55            embedding_paths = [
 56                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
 57            ]
 58        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)
 59
 60    t_run = time.time() - t_start
 61    minutes = int(t_run // 60)
 62    seconds = int(round(t_run % 60, 0))
 63    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")
 64
 65    return predictor, decoder, embedding_paths
 66
 67
 68def _get_input_shape(image, is_volumetric=False):
 69    if image.ndim == 2:
 70        image_shape = image.shape
 71    elif image.ndim == 3:
 72        if is_volumetric:
 73            image_shape = image.shape
 74        else:
 75            image_shape = image.shape[:-1]
 76    elif image.ndim == 4:
 77        image_shape = image.shape[:-1]
 78
 79    return image_shape
 80
 81
 82def _initialize_annotator(
 83    viewer, image, image_embedding_path,
 84    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
 85    precompute_amg_state, checkpoint_path, device, embedding_path,
 86    segmentation_path,
 87):
 88    if viewer is None:
 89        viewer = napari.Viewer()
 90    viewer.add_image(image, name="image")
 91
 92    state = AnnotatorState()
 93    state.initialize_predictor(
 94        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
 95        predictor=predictor, decoder=decoder,
 96        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
 97        checkpoint_path=checkpoint_path, device=device, skip_load=False,
 98    )
 99    state.image_shape = _get_input_shape(image, is_volumetric)
100
101    if is_volumetric:
102        if image.ndim not in [3, 4]:
103            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
104        annotator = Annotator3d(viewer)
105    else:
106        if image.ndim not in (2, 3):
107            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
108        annotator = Annotator2d(viewer)
109
110    if os.path.exists(segmentation_path):
111        segmentation_result = imageio.imread(segmentation_path)
112    else:
113        segmentation_result = None
114    annotator._update_image(segmentation_result=segmentation_result)
115
116    # Add the annotator widget to the viewer and sync widgets.
117    viewer.window.add_dock_widget(annotator)
118    _sync_embedding_widget(
119        widget=state.widgets["embeddings"],
120        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
121        save_path=embedding_path,
122        checkpoint_path=checkpoint_path,
123        device=device,
124        tile_shape=tile_shape,
125        halo=halo,
126    )
127    return viewer, annotator
128
129
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: The folder where the embeddings are saved (one zarr file per image).
            If `None`, the embeddings are not saved to disk.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to run the model on. By default the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("No images were passed for annotation.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Array inputs have no file name, so their results are named by index in the series.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _load_image(image_id):
        # Return the image data and its embedding path for the image at `image_id`.
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        return image, embedding_paths[image_id]

    def _load_segmentation(save_path):
        # An existing segmentation is used to populate the 'committed_objects' layer.
        return imageio.imread(save_path) if os.path.exists(save_path) else None

    def _find_image_to_annotate(start_id):
        # Return the first index >= start_id to annotate, or len(images) if none is left.
        # Already segmented images are only skipped when skip_segmented is set.
        image_id = start_id
        while (
            skip_segmented
            and image_id < len(images)
            and os.path.exists(_get_save_path(images[image_id], image_id))
        ):
            image_id += 1
        return image_id

    # Check which image to load first.
    next_image_id = _find_image_to_annotate(0)
    if next_image_id == len(images):
        # Only reachable with skip_segmented=True, since `images` is not empty.
        print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
        return
    if skip_segmented:
        print("The first image to annotate is image number", next_image_id)

    save_path = _get_save_path(images[next_image_id], next_image_id)
    image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _ask_to_close():
        # The series is finished: close napari unless the user declines in the dialog.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image, skipping already segmented ones if skip_segmented is set.
        next_image_id = _find_image_to_annotate(next_image_id + 1)

        # Check if we are done.
        if next_image_id == len(images):
            _ask_to_close()
            return

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        save_path = _get_save_path(images[next_image_id], next_image_id)
        segmentation_result = _load_segmentation(save_path)
        image, image_embedding_path = _load_image(next_image_id)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            checkpoint_path=checkpoint_path, skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
325
326
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The 2d annotator is used by default; pass `is_volumetric=True` (forwarded via `kwargs`)
    to annotate a series of 3d volumes instead.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    # Fail with a clear message instead of an IndexError deep inside the annotator.
    if not image_files:
        raise ValueError(f"No files matching the pattern '{pattern}' were found in '{input_folder}'.")

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )
355
356
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget for configuring and starting the annotation of all images in a folder.

    It collects the input/output folders, the model choice and a set of advanced settings
    (file pattern, 3d mode, device, embedding folder, custom weights, tiling), and hands
    them to `image_folder_annotator` when the run button is pressed.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Layout order: general options, collapsible advanced settings, run button.
        self._create_options()
        self.layout().addWidget(self._create_settings())

        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        main_layout = self.layout()

        # Input and output folder selectors; each attribute starts out unset (None).
        folder_specs = (
            ("folder", "Input Folder", "Folder with images ..."),
            ("output_folder", "Output Folder", "Folder to save the results ..."),
        )
        for attr_name, title, placeholder in folder_specs:
            setattr(self, attr_name, None)
            _, row_layout = self._add_path_param(
                attr_name, getattr(self, attr_name), "directory",
                title=title, placeholder=placeholder,
                tooltip=get_tooltip("image_series_annotator", attr_name)
            )
            main_layout.addLayout(row_layout)

        # The model family selector.
        main_layout.addLayout(self._create_model_section(create_layout=False))

    def _create_settings(self):
        container = QtWidgets.QWidget()
        container.setLayout(QtWidgets.QVBoxLayout())
        container_layout = container.layout()

        # Model size selector.
        container_layout.addLayout(self._create_model_size_section())

        # Glob pattern used to pick the files from the input folder.
        self.pattern = "*"
        _, row_layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        container_layout.addLayout(row_layout)

        # Toggle between the 2d and the 3d annotator.
        self.is_volumetric = False
        volumetric_toggle = self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        )
        container_layout.addWidget(volumetric_toggle)

        # Compute device.
        self.device = "auto"
        choices = ["auto"] + util._available_devices()
        self.device_dropdown, row_layout = self._add_choice_param(
            "device", self.device, choices, tooltip=get_tooltip("embedding", "device")
        )
        container_layout.addLayout(row_layout)

        # Folder for caching the embeddings.
        self.embeddings_save_path = None
        _, row_layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        container_layout.addLayout(row_layout)

        # Optional custom model checkpoint file.
        self.custom_weights = None
        _, row_layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        container_layout.addLayout(row_layout)

        # Tile shape for tiled prediction (0 means unset).
        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, row_layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        container_layout.addLayout(row_layout)

        # Halo (tile overlap) for tiled prediction.
        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, row_layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        container_layout.addLayout(row_layout)

        return widgets._make_collapsible(container, title="Advanced Settings")

    def _validate_inputs(self):
        # Collect every problem first so the user sees all of them in one message.
        problems = []
        if self.folder is None or not glob(os.path.join(self.folder, self.pattern)):
            problems.append("The input folder is missing or empty. ")
        if self.output_folder is None:
            problems.append("The output folder is missing.")

        if not problems:
            return False
        return widgets._generate_message("error", "".join(problems))

    def __call__(self, skip_validate=False):
        self._validate_model_type_and_custom_weights()

        if not skip_validate:
            if self._validate_inputs():
                return

        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        series_kwargs = dict(
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape,
            halo=halo,
            checkpoint_path=self.custom_weights,
            device=self.device,
            is_volumetric=self.is_volumetric,
        )
        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            viewer=self._viewer,
            return_viewer=True,
            **series_kwargs,
        )
480
481
def main():
    """@private"""
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The folder for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation. "
        "This makes precomputation slower, but automatic mask generation much faster."
    )
    # NOTE: the two flags below use 'store_false': the behavior is on by default and passing the flag disables it.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to NOT use decoder-based instance segmentation, even if the model has a decoder."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to NOT skip images that were already segmented; "
        "their existing segmentations are loaded instead."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def image_series_annotator( images: Union[List[Union[os.PathLike, str]], List[numpy.ndarray]], output_folder: str, model_type: str = 'vit_b_lm', embedding_path: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, is_volumetric: bool = False, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True, skip_segmented: bool = True) -> Optional[napari.viewer.Viewer]:
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: The folder where the embeddings are saved (one zarr file per image).
            If `None`, the embeddings are not saved to disk.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to run the model on. By default the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.
            If set to False, then segmentations that already exist will be loaded
            and used to populate the 'committed_objects' layer.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("No images were passed for annotation.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Array inputs have no file name, so their results are named by index in the series.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _load_image(image_id):
        # Return the image data and its embedding path for the image at `image_id`.
        image = images[image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        return image, embedding_paths[image_id]

    def _load_segmentation(save_path):
        # An existing segmentation is used to populate the 'committed_objects' layer.
        return imageio.imread(save_path) if os.path.exists(save_path) else None

    def _find_image_to_annotate(start_id):
        # Return the first index >= start_id to annotate, or len(images) if none is left.
        # Already segmented images are only skipped when skip_segmented is set.
        image_id = start_id
        while (
            skip_segmented
            and image_id < len(images)
            and os.path.exists(_get_save_path(images[image_id], image_id))
        ):
            image_id += 1
        return image_id

    # Check which image to load first.
    next_image_id = _find_image_to_annotate(0)
    if next_image_id == len(images):
        # Only reachable with skip_segmented=True, since `images` is not empty.
        print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
        return
    if skip_segmented:
        print("The first image to annotate is image number", next_image_id)

    save_path = _get_save_path(images[next_image_id], next_image_id)
    image, image_embedding_path = _load_image(next_image_id)

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
        save_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    def _ask_to_close():
        # The series is finished: close napari unless the user declines in the dialog.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image, skipping already segmented ones if skip_segmented is set.
        next_image_id = _find_image_to_annotate(next_image_id + 1)

        # Check if we are done.
        if next_image_id == len(images):
            _ask_to_close()
            return

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        save_path = _get_save_path(images[next_image_id], next_image_id)
        segmentation_result = _load_segmentation(save_path)
        image, image_embedding_path = _load_image(next_image_id)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            checkpoint_path=checkpoint_path, skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image(segmentation_result=segmentation_result)

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()

Run the annotation tool for a series of images (supported for both 2d and 3d images).

Arguments:
  • images: List of the file paths or list of (set of) slices for the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • embedding_path: Folder where the embeddings are saved (one zarr file per image).
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • is_volumetric: Whether to use the 3d annotator.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
  • skip_segmented: Whether to skip images that were already segmented. If set to False, then segmentations that already exist will be loaded and used to populate the 'committed_objects' layer.
Returns:

The napari viewer, only returned if return_viewer=True.

def image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`,
            e.g. `is_volumetric=True` to use the 3d annotator.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    # glob returns files in arbitrary (filesystem) order; sort for a deterministic annotation order.
    image_files = sorted(glob(os.path.join(input_folder, pattern)))

    # Fail early with a clear message instead of handing an empty list to the series annotator,
    # which indexes into the image list.
    if not image_files:
        raise ValueError(
            f"No files matching the pattern '{pattern}' were found in the folder '{input_folder}'."
        )

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )

Run the annotation tool for a series of images in a folder (2d, or 3d by passing is_volumetric=True in kwargs).

Arguments:
  • input_folder: The folder with the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • pattern: The glob pattern for loading files from input_folder. By default all files will be loaded.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • kwargs: The keyword arguments for micro_sam.sam_annotator.image_series_annotator.
Returns:

The napari viewer, only returned if return_viewer=True.

class ImageSeriesAnnotator(micro_sam.sam_annotator._widgets._WidgetBase):
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Widget that starts the annotation tool for a folder of images.

    The general options are the input folder, the output folder and the model family.
    A collapsible "Advanced Settings" section holds the model size, file pattern, 3d mode,
    device, embedding folder, custom weights, tiling and halo. Pressing "Annotate Images"
    calls `image_folder_annotator` in the viewer passed to the constructor.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        """Build the widget.

        Args:
            viewer: The napari viewer in which the annotation tool is started.
            parent: The optional parent widget.
        """
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button that starts the annotation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        # Drop the signal arguments: `clicked` carries a `checked` bool which could otherwise
        # be forwarded into `skip_validate` and bypass the input validation.
        self.run_button.clicked.connect(lambda *_: self.__call__())
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the general options: input folder, output folder and model family."""
        # NOTE(review): the path widgets presumably write their value back to the attribute
        # with the same name (`self.folder`, `self.output_folder`) - confirm in `_add_path_param`.
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        # Add the model family widget section.
        layout = self._create_model_section(create_layout=False)
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Create the collapsible "Advanced Settings" widget and initialize its attribute defaults.

        Returns:
            The collapsible widget containing all advanced settings.
        """
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        # Glob pattern used to select the files in the input folder.
        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None  # select_file
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        # NOTE(review): a tile/halo of 0 presumably means "no tiling"; the conversion happens in
        # `widgets._process_tiling_inputs` - confirm there.
        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Check that the input folder has files matching the pattern and that an output folder is set.

        Returns:
            The result of `widgets._generate_message` if an input is missing, otherwise False.
        """
        # An empty path counts as missing: `os.path.join("", pattern)` would otherwise glob
        # the current working directory and silently pass the check.
        missing_data = not self.folder or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = not self.output_folder
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Validate the inputs and start the annotation tool for the selected folder.

        Args:
            skip_validate: Whether to skip the check of the input and output folders.
        """
        self._validate_model_type_and_custom_weights()

        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            input_folder=self.folder,
            output_folder=self.output_folder,
            pattern=self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )

QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())

ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
def __init__(self, viewer: napari.Viewer, parent=None):
    """Assemble the widget: general options, collapsible settings and the run button.

    Args:
        viewer: The napari viewer in which the annotation tool is started.
        parent: The optional parent widget.
    """
    super().__init__(parent=parent)
    self._viewer = viewer

    # General options (input/output folder, model family).
    self._create_options()

    # Collapsible advanced settings.
    advanced = self._create_settings()
    self.layout().addWidget(advanced)

    # Button that triggers the annotation.
    button = QtWidgets.QPushButton("Annotate Images")
    self.run_button = button
    button.clicked.connect(self.__call__)
    self.layout().addWidget(button)
run_button
Inherited Members
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
changeEvent
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuEvent
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
effectiveWinId
ensurePolished
enterEvent
event
find
focusInEvent
focusNextChild
focusNextPrevChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyPressEvent
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumSizeHint
minimumWidth
mouseDoubleClickEvent
mouseGrabber
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
paintEvent
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
resizeEvent
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeHint
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
wheelEvent
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
eventFilter
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM