micro_sam.sam_annotator.image_series_annotator

  1import os
  2import time
  3from glob import glob
  4from pathlib import Path
  5from typing import List, Optional, Union, Tuple
  6
  7import numpy as np
  8import imageio.v3 as imageio
  9
 10import torch
 11
 12import napari
 13from magicgui import magicgui
 14from qtpy import QtWidgets
 15
 16from .. import util
 17from . import _widgets as widgets
 18from ._tooltips import get_tooltip
 19from ._state import AnnotatorState
 20from .annotator_2d import Annotator2d
 21from .annotator_3d import Annotator3d
 22from .util import _sync_embedding_widget
 23from ..instance_segmentation import get_decoder
 24from ..precompute_state import _precompute_state_for_files
 25
 26
 27def _precompute(
 28    images, model_type, embedding_path,
 29    tile_shape, halo, precompute_amg_state,
 30    checkpoint_path, device, ndim, prefer_decoder,
 31):
 32    t_start = time.time()
 33
 34    device = util.get_device(device)
 35    predictor, state = util.get_sam_model(
 36        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
 37    )
 38    if prefer_decoder and "decoder_state" in state:
 39        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 40    else:
 41        decoder = None
 42
 43    if embedding_path is None:
 44        embedding_paths = [None] * len(images)
 45    else:
 46        _precompute_state_for_files(
 47            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
 48            precompute_amg_state=precompute_amg_state, decoder=decoder,
 49        )
 50        if isinstance(images[0], np.ndarray):
 51            embedding_paths = [
 52                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
 53            ]
 54        else:
 55            embedding_paths = [
 56                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
 57            ]
 58        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)
 59
 60    t_run = time.time() - t_start
 61    minutes = int(t_run // 60)
 62    seconds = int(round(t_run % 60, 0))
 63    print(f"Precomputation took {t_run:.1f} seconds (= {minutes:02}:{seconds:02} minutes)")
 64
 65    return predictor, decoder, embedding_paths
 66
 67
 68def _get_input_shape(image, is_volumetric=False):
 69    if image.ndim == 2:
 70        image_shape = image.shape
 71    elif image.ndim == 3:
 72        # Without is_volumetric, a 3d input is interpreted as 2d + channels.
 73        image_shape = image.shape if is_volumetric else image.shape[:-1]
 74    elif image.ndim == 4:
 75        image_shape = image.shape[:-1]
 76    else:
 77        raise ValueError(f"Invalid image dimensions, expected 2 to 4 dimensions, got {image.ndim}")
 78
 79    return image_shape
 80
 81
 82def _initialize_annotator(
 83    viewer, image, image_embedding_path,
 84    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
 85    precompute_amg_state, checkpoint_path, device, embedding_path,
 86    segmentation_path,
 87):
 88    if viewer is None:
 89        viewer = napari.Viewer()
 90    viewer.add_image(image, name="image")
 91
 92    state = AnnotatorState()
 93    state.initialize_predictor(
 94        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
 95        predictor=predictor, decoder=decoder,
 96        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
 97        checkpoint_path=checkpoint_path, device=device, skip_load=False, use_cli=True,
 98    )
 99    state.image_shape = _get_input_shape(image, is_volumetric)
100
101    if is_volumetric:
102        if image.ndim not in [3, 4]:
103            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
104        annotator = Annotator3d(viewer)
105    else:
106        if image.ndim not in (2, 3):
107            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
108        annotator = Annotator2d(viewer)
109
110    if os.path.exists(segmentation_path):
111        segmentation_result = imageio.imread(segmentation_path)
112    else:
113        segmentation_result = None
114    annotator._update_image(segmentation_result=segmentation_result)
115
116    # Add the annotator widget to the viewer and sync widgets.
117    viewer.window.add_dock_widget(annotator)
118    _sync_embedding_widget(
119        widget=state.widgets["embeddings"],
120        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
121        save_path=embedding_path,
122        checkpoint_path=checkpoint_path,
123        device=device,
124        tile_shape=tile_shape,
125        halo=halo,
126    )
127    return viewer, annotator
128
129
130def image_series_annotator(
131    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
132    output_folder: str,
133    model_type: str = util._DEFAULT_MODEL,
134    embedding_path: Optional[str] = None,
135    tile_shape: Optional[Tuple[int, int]] = None,
136    halo: Optional[Tuple[int, int]] = None,
137    viewer: Optional["napari.viewer.Viewer"] = None,
138    return_viewer: bool = False,
139    precompute_amg_state: bool = False,
140    checkpoint_path: Optional[str] = None,
141    is_volumetric: bool = False,
142    device: Optional[Union[str, torch.device]] = None,
143    prefer_decoder: bool = True,
144    skip_segmented: bool = True,
145) -> Optional["napari.viewer.Viewer"]:
146    """Run the annotation tool for a series of images (supported for both 2d and 3d images).
147
148    Args:
149        images: List of file paths, or list of image arrays, for the images to be annotated.
150        output_folder: The folder where the segmentation results are saved.
151        model_type: The Segment Anything model to use. For details on the available models check out
152            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
153        embedding_path: The folder where the embeddings will be saved.
154        tile_shape: Shape of tiles for tiled embedding prediction.
155            If `None` then the whole image is passed to Segment Anything.
156        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
157        viewer: The viewer to which the Segment Anything functionality should be added.
158            This enables using a pre-initialized viewer.
159        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
160            By default, does not return the napari viewer.
161        precompute_amg_state: Whether to precompute the state for automatic mask generation.
162            This will take more time when precomputing embeddings, but will then make
163            automatic mask generation much faster. By default, set to 'False'.
164        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
165        is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
166        prefer_decoder: Whether to use decoder based instance segmentation if
167            the model used has an additional decoder for instance segmentation.
168            By default, set to 'True'.
169        skip_segmented: Whether to skip images that were already segmented.
170            If set to False, then segmentations that already exist will be loaded
171            and used to populate the 'committed_objects' layer.
172
173    Returns:
174        The napari viewer, only returned if `return_viewer=True`.
175    """
176    end_msg = "You have annotated the last image. Do you wish to close napari?"
177    os.makedirs(output_folder, exist_ok=True)
178
179    # Precompute embeddings and amg state (if corresponding options set).
180    predictor, decoder, embedding_paths = _precompute(
181        images, model_type,
182        embedding_path, tile_shape, halo, precompute_amg_state,
183        checkpoint_path=checkpoint_path, device=device,
184        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
185    )
186
187    next_image_id = 0
188    have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray)
189
190    def _get_save_path(image_path, current_idx):
191        if have_inputs_as_arrays:
192            fname = f"seg_{current_idx:05}.tif"
193        else:
194            fname = os.path.basename(image_path)
195            fname = os.path.splitext(fname)[0] + ".tif"
196        return os.path.join(output_folder, fname)
197
198    def _load_image(image_id):
199        image = images[next_image_id]
200        if not have_inputs_as_arrays:
201            image = imageio.imread(image)
202        image_embedding_path = embedding_paths[next_image_id]
203        return image, image_embedding_path
204
205    # Check which image to load next if we skip segmented images.
206    if skip_segmented:
207        while True:
208            if next_image_id == len(images):
209                print("All images have already been annotated and you have set 'skip_segmented=True'. Nothing to do.")
210                return
211
212            save_path = _get_save_path(images[next_image_id], next_image_id)
213            if not os.path.exists(save_path):
214                print("The first image to annotate is image number", next_image_id)
215                image, image_embedding_path = _load_image(next_image_id)
216                break
217
218            next_image_id += 1
219
220    else:
221        save_path = _get_save_path(images[next_image_id], next_image_id)
222        image, image_embedding_path = _load_image(next_image_id)
223
224    # Initialize the viewer and annotator for this image.
225    state = AnnotatorState()
226    viewer, annotator = _initialize_annotator(
227        viewer, image, image_embedding_path,
228        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
229        precompute_amg_state, checkpoint_path, device, embedding_path,
230        save_path,
231    )
232
233    def _save_segmentation(image_path, current_idx, segmentation):
234        save_path = _get_save_path(image_path, next_image_id)
235        imageio.imwrite(save_path, segmentation, compression="zlib")
236
237    # Add functionality for going to the next image.
238    @magicgui(call_button="Next Image [N]")
239    def next_image(*args):
240        nonlocal next_image_id
241
242        segmentation = viewer.layers["committed_objects"].data
243        abort = False
244        if segmentation.sum() == 0:
245            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
246            abort = widgets._generate_message("info", msg)
247            if abort:
248                return
249
250        # Save the current segmentation.
251        _save_segmentation(images[next_image_id], next_image_id, segmentation)
252
253        # Clear the segmentation now so that removing it later does not lag.
254        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)
255
256        # Go to the next image.
257        next_image_id += 1
258
259        # Check if we are done.
260        if next_image_id == len(images):
261            # Inform the user via dialog.
262            abort = widgets._generate_message("info", end_msg)
263            if not abort:
264                viewer.close()
265            return
266
267        # If we are skipping images that are already segmented, then check if we have to load the next image.
268        save_path = _get_save_path(images[next_image_id], next_image_id)
269        if skip_segmented:
270            segmentation_result = None
271            while os.path.exists(save_path):
272                next_image_id += 1
273
274                # Check if we are done.
275                if next_image_id == len(images):
276                    # Inform the user via dialog.
277                    abort = widgets._generate_message("info", end_msg)
278                    if not abort:
279                        viewer.close()
280                    return
281
282                save_path = _get_save_path(images[next_image_id], next_image_id)
283        else:
284            if os.path.exists(save_path):
285                segmentation_result = imageio.imread(save_path)
286            else:
287                segmentation_result = None
288
289        print(
290            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
291        )
292
293        if have_inputs_as_arrays:
294            image = images[next_image_id]
295        else:
296            image = imageio.imread(images[next_image_id])
297
298        image_embedding_path = embedding_paths[next_image_id]
299
300        # Set the new image in the viewer, state and annotator.
301        viewer.layers["image"].data = image
302
303        if state.amg is not None:
304            state.amg.clear_state()
305
306        state.initialize_predictor(
307            image, model_type=model_type, ndim=3 if is_volumetric else 2,
308            save_path=image_embedding_path,
309            tile_shape=tile_shape, halo=halo,
310            predictor=predictor, decoder=decoder,
311            precompute_amg_state=precompute_amg_state, device=device,
312            skip_load=False,
313        )
314        state.image_shape = _get_input_shape(image, is_volumetric)
315
316        annotator._update_image(segmentation_result=segmentation_result)
317
318    viewer.window.add_dock_widget(next_image)
319
320    @viewer.bind_key("n", overwrite=True)
321    def _next_image(viewer):
322        next_image(viewer)
323
324    if return_viewer:
325        return viewer
326    napari.run()
327
328
329def image_folder_annotator(
330    input_folder: str,
331    output_folder: str,
332    pattern: str = "*",
333    viewer: Optional["napari.viewer.Viewer"] = None,
334    return_viewer: bool = False,
335    **kwargs
336) -> Optional["napari.viewer.Viewer"]:
337    """Run the 2d annotation tool for a series of images in a folder.
338
339    Args:
340        input_folder: The folder with the images to be annotated.
341        output_folder: The folder where the segmentation results are saved.
342        pattern: The glob pattern for loading files from `input_folder`.
343            By default all files will be loaded.
344        viewer: The viewer to which the Segment Anything functionality should be added.
345            This enables using a pre-initialized viewer.
346        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
347            By default, does not return the napari viewer.
348        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.
349
350    Returns:
351        The napari viewer, only returned if `return_viewer=True`.
352    """
353    image_files = sorted(glob(os.path.join(input_folder, pattern)))
354
355    return image_series_annotator(
356        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
357    )
358
359
360class ImageSeriesAnnotator(widgets._WidgetBase):
361    def __init__(self, viewer: napari.Viewer, parent=None):
362        super().__init__(parent=parent)
363        self._viewer = viewer
364
365        # Create the UI: the general options.
366        self._create_options()
367
368        # Add the settings (collapsible).
369        self.layout().addWidget(self._create_settings())
370
371        # Add the run button to trigger the embedding computation.
372        self.run_button = QtWidgets.QPushButton("Annotate Images")
373        self.run_button.clicked.connect(self.__call__)
374        self.layout().addWidget(self.run_button)
375
376    def _create_options(self):
377        self.folder = None
378        _, layout = self._add_path_param(
379            "folder", self.folder, "directory",
380            title="Input Folder", placeholder="Folder with images ...",
381            tooltip=get_tooltip("image_series_annotator", "folder")
382        )
383        self.layout().addLayout(layout)
384
385        self.output_folder = None
386        _, layout = self._add_path_param(
387            "output_folder", self.output_folder, "directory",
388            title="Output Folder", placeholder="Folder to save the results ...",
389            tooltip=get_tooltip("image_series_annotator", "output_folder")
390        )
391        self.layout().addLayout(layout)
392
393        # Add the model family widget section.
394        layout = self._create_model_section(create_layout=False)
395        self.layout().addLayout(layout)
396
397    def _create_settings(self):
398        setting_values = QtWidgets.QWidget()
399        setting_values.setLayout(QtWidgets.QVBoxLayout())
400
401        # Add the model size widget section.
402        layout = self._create_model_size_section()
403        setting_values.layout().addLayout(layout)
404
405        self.pattern = "*"
406        _, layout = self._add_string_param(
407            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
408        )
409        setting_values.layout().addLayout(layout)
410
411        self.is_volumetric = False
412        setting_values.layout().addWidget(self._add_boolean_param(
413            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
414        ))
415
416        self.device = "auto"
417        device_options = ["auto"] + util._available_devices()
418        self.device_dropdown, layout = self._add_choice_param(
419            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
420        )
421        setting_values.layout().addLayout(layout)
422
423        self.embeddings_save_path = None
424        _, layout = self._add_path_param(
425            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
426            tooltip=get_tooltip("embedding", "embeddings_save_path")
427        )
428        setting_values.layout().addLayout(layout)
429
430        self.custom_weights = None  # select_file
431        _, layout = self._add_path_param(
432            "custom_weights", self.custom_weights, "file", title="custom weights path:",
433            tooltip=get_tooltip("embedding", "custom_weights")
434        )
435        setting_values.layout().addLayout(layout)
436
437        self.tile_x, self.tile_y = 0, 0
438        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
439            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
440            tooltip=get_tooltip("embedding", "tiling")
441        )
442        setting_values.layout().addLayout(layout)
443
444        self.halo_x, self.halo_y = 0, 0
445        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
446            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
447            tooltip=get_tooltip("embedding", "halo")
448        )
449        setting_values.layout().addLayout(layout)
450
451        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
452        return settings
453
454    def _validate_inputs(self):
455        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
456        missing_output = self.output_folder is None
457        if missing_data or missing_output:
458            msg = ""
459            if missing_data:
460                msg += "The input folder is missing or empty. "
461            if missing_output:
462                msg += "The output folder is missing."
463            return widgets._generate_message("error", msg)
464        return False
465
466    def __call__(self, skip_validate=False):
467        self._validate_model_type_and_custom_weights()
468
469        if not skip_validate and self._validate_inputs():
470            return
471        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
472
473        image_folder_annotator(
474            input_folder=self.folder,
475            output_folder=self.output_folder,
476            pattern=self.pattern,
477            model_type=self.model_type,
478            embedding_path=self.embeddings_save_path,
479            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
480            device=self.device, is_volumetric=self.is_volumetric,
481            viewer=self._viewer, return_viewer=True,
482        )
483
484
485def main():
486    """@private"""
487    import argparse
488
489    available_models = list(util.get_model_names())
490    available_models = ", ".join(available_models)
491
492    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
493    parser.add_argument(
494        "-i", "--input_folder", required=True,
495        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
496    )
497    parser.add_argument(
498        "-o", "--output_folder", required=True,
499        help="The folder where the segmentation results will be stored."
500    )
501    parser.add_argument(
502        "-p", "--pattern", default="*",
503        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
504        "By default all files in the folder will be loaded and annotated."
505    )
506    parser.add_argument(
507        "-e", "--embedding_path",
508        help="The filepath for saving/loading the pre-computed image embeddings. "
509        "NOTE: It is recommended to pass this argument and store the embeddings, "
510        "otherwise they will be recomputed every time (which can take a long time)."
511    )
512    parser.add_argument(
513        "-m", "--model_type", default=util._DEFAULT_MODEL,
514        help=f"The segment anything model that will be used, one of {available_models}."
515    )
516    parser.add_argument(
517        "-c", "--checkpoint", default=None,
518        help="Checkpoint from which the SAM model will be loaded."
519    )
520    parser.add_argument(
521        "-d", "--device", default=None,
522        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (Mac only). "
523        "By default the most performant available device will be selected."
524    )
525    parser.add_argument(
526        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
527    )
528
529    parser.add_argument(
530        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
531    )
532    parser.add_argument(
533        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
534    )
535    parser.add_argument("--precompute_amg_state", action="store_true", help="Precompute the state for automatic mask generation.")
536    parser.add_argument("--prefer_decoder", action="store_false", help="Pass this flag to disable decoder based instance segmentation.")
537    parser.add_argument("--skip_segmented", action="store_false", help="Pass this flag to load existing segmentations instead of skipping the corresponding images.")
538
539    args = parser.parse_args()
540
541    image_folder_annotator(
542        args.input_folder, args.output_folder, args.pattern,
543        embedding_path=args.embedding_path, model_type=args.model_type,
544        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
545        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
546        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
547    )
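
For reference, the command line interface defined by `main` can also be driven programmatically. A minimal sketch, assuming the package is installed; the folder paths are placeholders and the flags match the argparse definitions above:

    import sys
    from micro_sam.sam_annotator.image_series_annotator import main

    # Equivalent to calling the CLI with:
    #   -i ./images -o ./segmentations -e ./embeddings -m vit_b_lm
    sys.argv = [
        "image_series_annotator",
        "-i", "./images",
        "-o", "./segmentations",
        "-e", "./embeddings",
        "-m", "vit_b_lm",
    ]
    main()
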
def image_series_annotator(
    images: Union[List[Union[str, os.PathLike]], List[numpy.ndarray]],
    output_folder: str,
    model_type: str = 'vit_b_lm',
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional[napari.viewer.Viewer] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional[napari.viewer.Viewer]:

Run the annotation tool for a series of images (supported for both 2d and 3d images).

Arguments:
  • images: List of file paths, or list of image arrays, for the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • embedding_path: The folder where the embeddings will be saved.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • is_volumetric: Whether to use the 3d annotator. By default, set to 'False'.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
  • skip_segmented: Whether to skip images that were already segmented. If set to False, then segmentations that already exist will be loaded and used to populate the 'committed_objects' layer.
Returns:

The napari viewer, only returned if return_viewer=True.
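
A minimal usage sketch; the data and output paths are placeholders, and embeddings are cached under embedding_path so they are not recomputed on restart:

    from glob import glob
    from micro_sam.sam_annotator.image_series_annotator import image_series_annotator

    # Placeholder input data: all tif files in a local folder.
    images = sorted(glob("./data/*.tif"))

    # Opens napari on the first unsegmented image; "Next Image [N]" saves the
    # committed segmentation to the output folder and loads the next image.
    image_series_annotator(
        images,
        output_folder="./segmentations",
        embedding_path="./embeddings",
    )

The images can also be passed as a list of numpy arrays, in which case the segmentations are saved as seg_00000.tif, seg_00001.tif, etc. in the output folder.
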

def image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:

Run the 2d annotation tool for a series of images in a folder.

Arguments:
  • input_folder: The folder with the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • pattern: The glob pattern for loading files from input_folder. By default all files will be loaded.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • kwargs: The keyword arguments for micro_sam.sam_annotator.image_series_annotator.
Returns:

The napari viewer, only returned if return_viewer=True.
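
A short sketch of the folder-based entry point; the paths are placeholders, and all extra keyword arguments are forwarded to image_series_annotator:

    from micro_sam.sam_annotator.image_series_annotator import image_folder_annotator

    # Annotate every png in the input folder; precompute_amg_state is
    # forwarded to image_series_annotator.
    image_folder_annotator(
        input_folder="./images",
        output_folder="./segmentations",
        pattern="*.png",
        embedding_path="./embeddings",
        precompute_amg_state=True,
    )
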

class ImageSeriesAnnotator(micro_sam.sam_annotator._widgets._WidgetBase):

ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
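
The widget is intended to be docked into a napari viewer. A minimal sketch, assuming a plain napari session; the input and output folders, model and tiling options are then configured in the UI, and "Annotate Images" starts the annotation loop:

    import napari
    from micro_sam.sam_annotator.image_series_annotator import ImageSeriesAnnotator

    viewer = napari.Viewer()
    # Dock the widget into the viewer window.
    viewer.window.add_dock_widget(ImageSeriesAnnotator(viewer))
    napari.run()
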