micro_sam.sam_annotator.image_series_annotator

import os
import time

from glob import glob
from pathlib import Path
from typing import List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

import torch

import napari
from magicgui import magicgui
from qtpy import QtWidgets

from .. import util
from . import _widgets as widgets
from ..precompute_state import _precompute_state_for_files
from ..instance_segmentation import get_decoder
from .annotator_2d import Annotator2d
from .annotator_3d import Annotator3d
from ._state import AnnotatorState
from .util import _sync_embedding_widget
from ._tooltips import get_tooltip


def _precompute(
    images, model_type, embedding_path,
    tile_shape, halo, precompute_amg_state,
    checkpoint_path, device, ndim, prefer_decoder,
):
    t_start = time.time()

    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
    )
    if prefer_decoder and "decoder_state" in state:
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    else:
        decoder = None

    if embedding_path is None:
        embedding_paths = [None] * len(images)
    else:
        _precompute_state_for_files(
            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
        )
        if isinstance(images[0], np.ndarray):
            embedding_paths = [
                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i in range(len(images))
            ]
        else:
            embedding_paths = [
                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
            ]
        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)

    t_run = time.time() - t_start
    minutes = int(t_run // 60)
    seconds = int(round(t_run % 60))
    print(f"Precomputation took {t_run:.1f} seconds (= {minutes:02}:{seconds:02} minutes)")

    return predictor, decoder, embedding_paths


def _get_input_shape(image, is_volumetric=False):
    if image.ndim == 2:
        image_shape = image.shape
    elif image.ndim == 3:
        # A 3d array is either a volume or a 2d image with channels.
        image_shape = image.shape if is_volumetric else image.shape[:-1]
    elif image.ndim == 4:
        image_shape = image.shape[:-1]
    else:
        raise ValueError(f"Invalid image dimensionality: {image.ndim}")

    return image_shape


def _initialize_annotator(
    viewer, image, image_embedding_path,
    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
    precompute_amg_state, checkpoint_path, device,
    embedding_path,
):
    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
        predictor=predictor, decoder=decoder,
        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, skip_load=False,
    )
    state.image_shape = _get_input_shape(image, is_volumetric)

    if is_volumetric:
        if image.ndim not in (3, 4):
            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
        annotator = Annotator3d(viewer)
    else:
        if image.ndim not in (2, 3):
            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
        annotator = Annotator2d(viewer)

    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo,
    )
    return viewer, annotator


def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    Args:
        images: List of the file paths or list of (sets of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Filepath where to save the embeddings.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The computational device to use for the SAM model.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if the corresponding options are set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    next_image_id = 0
    have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray)

    def _get_save_path(image_path, current_idx):
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    # Check which image to load next if we skip segmented images.
    image_embedding_path = None
    if skip_segmented:
        while True:
            if next_image_id == len(images):
                print(end_msg)
                return

            save_path = _get_save_path(images[next_image_id], next_image_id)
            if not os.path.exists(save_path):
                print("The first image to annotate is image number", next_image_id)
                image = images[next_image_id]
                if not have_inputs_as_arrays:
                    image = imageio.imread(image)
                image_embedding_path = embedding_paths[next_image_id]
                break

            next_image_id += 1
    else:
        # Otherwise, start with the first image.
        image = images[next_image_id]
        if not have_inputs_as_arrays:
            image = imageio.imread(image)
        image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device,
        embedding_path,
    )

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        segmentation = viewer.layers["committed_objects"].data
        abort = False
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation now to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

        # Go to the next image; if we skip segmented images, find the next one that still needs annotation.
        next_image_id += 1
        if skip_segmented and next_image_id < len(images):
            save_path = _get_save_path(images[next_image_id], next_image_id)
            while os.path.exists(save_path):
                next_image_id += 1
                if next_image_id == len(images):
                    break
                save_path = _get_save_path(images[next_image_id], next_image_id)

        # Load the next image.
        if next_image_id == len(images):
            # Inform the user via dialog.
            abort = widgets._generate_message("info", end_msg)
            if not abort:
                viewer.close()
            return

        print(
            "Loading next image:",
            images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        if have_inputs_as_arrays:
            image = images[next_image_id]
        else:
            image = imageio.imread(images[next_image_id])

        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()
        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()


def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the 2d annotation tool for a series of images in a folder.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )


class ImageSeriesAnnotator(widgets._WidgetBase):
    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button to trigger the annotation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        self.model_type = util._DEFAULT_MODEL
        model_options = list(util.models().urls.keys())
        model_options = [model for model in model_options if not model.endswith("decoder")]
        _, layout = self._add_choice_param(
            "model_type", self.model_type, model_options, title="Model:",
            tooltip=get_tooltip("embedding", "model")
        )
        self.layout().addLayout(layout)

    def _create_settings(self):
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            self.folder, self.output_folder, self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder, e.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, default=None, help="The tile shape for using tiled prediction."
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, default=None, help="The halo for using tiled prediction."
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Precompute the state for automatic mask generation."
    )
    # Note: the two flags below use action="store_false", so passing them disables the default behavior.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to disable decoder-based instance segmentation (used by default if available)."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to also annotate images that already have a segmentation result."
    )

    args = parser.parse_args()

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def image_series_annotator(
    images: Union[List[Union[str, os.PathLike]], List[numpy.ndarray]],
    output_folder: str,
    model_type: str = 'vit_l',
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional[napari.viewer.Viewer] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional[napari.viewer.Viewer]:

Run the annotation tool for a series of images (supported for both 2d and 3d images).

Arguments:
  • images: List of the file paths or list of (set of) slices for the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • embedding_path: Filepath where to save the embeddings.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • is_volumetric: Whether to use the 3d annotator.
  • device: The computational device to use for the SAM model.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
  • skip_segmented: Whether to skip images that were already segmented.
Returns:
  The napari viewer, only returned if return_viewer=True.
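
To make the API concrete, here is a minimal usage sketch; the file names and the model choice are illustrative assumptions. For file inputs, each segmentation is saved as <image-stem>.tif in output_folder and each embedding as <image-stem>.zarr under embedding_path; for in-memory array inputs the names are seg_00000.tif, embedding_00000.zarr, and so on.

from micro_sam.sam_annotator.image_series_annotator import image_series_annotator

# Hypothetical input files; any image format readable by imageio works.
images = ["data/im0.tif", "data/im1.tif", "data/im2.tif"]

image_series_annotator(
    images,
    output_folder="segmentations",  # results are written here as tif files
    embedding_path="embeddings",    # store embeddings so they are not recomputed on restart
    model_type="vit_b_lm",          # assumption: any valid micro_sam model name can be used
)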

def image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:

Run the 2d annotation tool for a series of images in a folder.

Arguments:
  • input_folder: The folder with the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • pattern: The glob pattern for loading files from input_folder. By default all files will be loaded.
  • viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • kwargs: The keyword arguments for micro_sam.sam_annotator.image_series_annotator.
Returns:
  The napari viewer, only returned if return_viewer=True.
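
A sketch of annotating all tif files in a folder, with tiled embedding computation for large images; the folder names and tile sizes are illustrative assumptions, and extra keyword arguments are forwarded to image_series_annotator.

from micro_sam.sam_annotator.image_series_annotator import image_folder_annotator

image_folder_annotator(
    "data/images", "data/segmentations", pattern="*.tif",
    tile_shape=(1024, 1024), halo=(256, 256),  # tiled prediction with overlapping tiles
    is_volumetric=False,  # set to True to annotate a folder of 3d volumes instead
)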

class ImageSeriesAnnotator(micro_sam.sam_annotator._widgets._WidgetBase):

ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
run_button: The "Annotate Images" button that starts the annotation (connected to __call__).
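
The widget can also be instantiated directly and docked into an existing napari viewer; a minimal sketch (input and output folders are then selected through the widget's UI):

import napari
from micro_sam.sam_annotator.image_series_annotator import ImageSeriesAnnotator

viewer = napari.Viewer()
widget = ImageSeriesAnnotator(viewer)
viewer.window.add_dock_widget(widget, name="Image Series Annotator")
napari.run()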