micro_sam.sam_annotator.image_series_annotator

  1import os
  2import time
  3from glob import glob
  4from pathlib import Path
  5from typing import List, Optional, Union, Tuple
  6
  7import numpy as np
  8import imageio.v3 as imageio
  9
 10import torch
 11
 12import napari
 13from magicgui import magicgui
 14from qtpy import QtWidgets
 15
 16from .. import util
 17from . import _widgets as widgets
 18from ._tooltips import get_tooltip
 19from ._state import AnnotatorState
 20from .annotator_2d import Annotator2d
 21from .annotator_3d import Annotator3d
 22from .util import _sync_embedding_widget
 23from ..instance_segmentation import get_decoder
 24from ..precompute_state import _precompute_state_for_files
 25
 26
 27def _precompute(
 28    images, model_type, embedding_path,
 29    tile_shape, halo, precompute_amg_state,
 30    checkpoint_path, device, ndim, prefer_decoder,
 31):
 32    t_start = time.time()
 33
 34    device = util.get_device(device)
 35    predictor, state = util.get_sam_model(
 36        model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
 37    )
 38    if prefer_decoder and "decoder_state" in state:
 39        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 40    else:
 41        decoder = None
 42
 43    if embedding_path is None:
 44        embedding_paths = [None] * len(images)
 45    else:
 46        _precompute_state_for_files(
 47            predictor, images, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
 48            precompute_amg_state=precompute_amg_state, decoder=decoder,
 49        )
 50        if isinstance(images[0], np.ndarray):
 51            embedding_paths = [
 52                os.path.join(embedding_path, f"embedding_{i:05}.zarr") for i, path in enumerate(images)
 53            ]
 54        else:
 55            embedding_paths = [
 56                os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in images
 57            ]
 58        assert all(os.path.exists(emb_path) for emb_path in embedding_paths)
 59
 60    t_run = time.time() - t_start
 61    minutes = int(t_run // 60)
 62    seconds = int(round(t_run % 60, 0))
 63    print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)")
 64
 65    return predictor, decoder, embedding_paths
 66
 67
 68def _get_input_shape(image, is_volumetric=False):
 69    if image.ndim == 2:
 70        image_shape = image.shape
 71    elif image.ndim == 3:
 72        if is_volumetric:
 73            image_shape = image.shape
 74        else:
 75            image_shape = image.shape[:-1]
 76    elif image.ndim == 4:
 77        image_shape = image.shape[:-1]
 78
 79    return image_shape
 80
 81
 82def _initialize_annotator(
 83    viewer, image, image_embedding_path,
 84    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
 85    precompute_amg_state, checkpoint_path, device, embedding_path,
 86):
 87    if viewer is None:
 88        viewer = napari.Viewer()
 89    viewer.add_image(image, name="image")
 90
 91    state = AnnotatorState()
 92    state.initialize_predictor(
 93        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
 94        predictor=predictor, decoder=decoder,
 95        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
 96        checkpoint_path=checkpoint_path, device=device, skip_load=False,
 97    )
 98    state.image_shape = _get_input_shape(image, is_volumetric)
 99
100    if is_volumetric:
101        if image.ndim not in [3, 4]:
102            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
103        annotator = Annotator3d(viewer)
104    else:
105        if image.ndim not in (2, 3):
106            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
107        annotator = Annotator2d(viewer)
108
109    annotator._update_image()
110
111    # Add the annotator widget to the viewer and sync widgets.
112    viewer.window.add_dock_widget(annotator)
113    _sync_embedding_widget(
114        state.widgets["embeddings"], model_type,
115        save_path=embedding_path, checkpoint_path=checkpoint_path,
116        device=device, tile_shape=tile_shape, halo=halo
117    )
118    return viewer, annotator
119
120
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    The segmentations are saved as tif files in `output_folder`: as `seg_<index>.tif` if the
    images are passed as arrays, otherwise with the input file name and a `.tif` extension.

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Folder where the embeddings are saved, one embedding file per image.
            If `None`, the embeddings are not precomputed and not saved.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to run the model on. By default the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented, i.e. images
            for which a segmentation file already exists in `output_folder`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("No images were passed for annotation.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Path where the segmentation of image `current_idx` is (or will be) stored.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _find_next_image_id(start_idx):
        # First index >= start_idx that still needs annotation; len(images) if none is left.
        # The bound is checked before indexing, so this never reads past the end of `images`.
        idx = start_idx
        if skip_segmented:
            while idx < len(images) and os.path.exists(_get_save_path(images[idx], idx)):
                idx += 1
        return idx

    def _load_image(idx):
        # Inputs are either arrays already, or file paths that have to be read.
        return images[idx] if have_inputs_as_arrays else imageio.imread(images[idx])

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    # Determine the first image to annotate; this also runs when skip_segmented is False.
    next_image_id = _find_next_image_id(0)
    if next_image_id == len(images):
        print(end_msg)
        return
    if skip_segmented:
        print("The first image to annotate is image number", next_image_id)

    image = _load_image(next_image_id)
    image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
    )

    def _show_end_dialog():
        # Inform the user via dialog; the viewer is closed unless the dialog result is truthy.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images are done already (napari was kept open): there is nothing to save or load.
        if next_image_id >= len(images):
            _show_end_dialog()
            return

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(segmentation)

        # Go to the next image, skipping images that are already segmented if requested.
        next_image_id = _find_next_image_id(next_image_id + 1)
        if next_image_id == len(images):
            _show_end_dialog()
            return

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        image = _load_image(next_image_id)
        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state,
            checkpoint_path=checkpoint_path, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()
291
292
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images in a folder.

    The matching files are annotated in sorted path order. Pass `is_volumetric=True`
    (via `kwargs`) to use the 3d annotator.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern for loading files from `input_folder`.
            By default all files will be loaded.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If no file in `input_folder` matches `pattern`.
    """
    image_files = sorted(glob(os.path.join(input_folder, pattern)))
    # Fail with a clear message instead of an IndexError deeper in the annotator.
    if not image_files:
        raise ValueError(f"No files matching the pattern '{pattern}' were found in '{input_folder}'.")

    return image_series_annotator(
        image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
    )
321
322
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Napari widget for configuring and starting the annotation of a folder of images.

    Pressing the "Annotate Images" button calls `image_folder_annotator` with the options
    selected in this widget, using the viewer passed to the constructor.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button that starts the annotation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the general options: input folder, output folder and model type."""
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        self.model_type = util._DEFAULT_MODEL
        # Models whose name ends with "decoder" are excluded from the choices.
        model_options = list(util.models().urls.keys())
        model_options = [model for model in model_options if not model.endswith("decoder")]
        _, layout = self._add_choice_param(
            "model_type", self.model_type, model_options, title="Model:",
            tooltip=get_tooltip("embedding", "model")
        )
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Create the advanced settings and return them wrapped in a collapsible widget."""
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        # Path to a custom checkpoint file.
        self.custom_weights = None
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        # Tile and halo values of 0 are converted by `widgets._process_tiling_inputs` in `__call__`.
        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Check that the input folder has matching files and that an output folder is set.

        Empty strings are treated as missing: an empty input folder would otherwise glob the
        current working directory, and an empty output folder cannot be created.

        Returns:
            The result of the error message dialog if an input is missing, otherwise False.
        """
        missing_data = not self.folder or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = not self.output_folder
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Start the annotation of the selected folder with the current settings.

        Args:
            skip_validate: Whether to skip the input validation.
        """
        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            self.folder, self.output_folder, self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )
444
445
def main():
    """@private"""
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")
    parser.add_argument(
        "-i", "--input_folder", required=True,
        help="The folder containing the image data. The data can be stored in any common format (tif, jpg, png, ...)."
    )
    parser.add_argument(
        "-o", "--output_folder", required=True,
        help="The folder where the segmentation results will be stored."
    )
    parser.add_argument(
        "-p", "--pattern", default="*",
        help="The pattern to select the images to annotate from the input folder. E.g. *.tif to annotate all tifs. "
        "By default all files in the folder will be loaded and annotated."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The folder for saving/loading the pre-computed image embeddings. "
        "NOTE: It is recommended to pass this argument and store the embeddings, "
        "otherwise they will be recomputed every time (which can take a long time)."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--is_volumetric", action="store_true", help="Whether to use the 3d annotator for a set of 3d volumes."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )
    parser.add_argument(
        "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic mask generation. "
        "This takes more time when precomputing the embeddings, but makes automatic segmentation faster."
    )
    # NOTE: the two flags below are `store_false`: passing them DISABLES the behavior.
    parser.add_argument(
        "--prefer_decoder", action="store_false",
        help="Pass this flag to NOT use decoder based instance segmentation, even if the model has a decoder."
    )
    parser.add_argument(
        "--skip_segmented", action="store_false",
        help="Pass this flag to NOT skip images that already have a segmentation in the output folder."
    )

    args = parser.parse_args()

    # argparse returns lists for nargs="+"; the annotator documents tuples.
    tile_shape = None if args.tile_shape is None else tuple(args.tile_shape)
    halo = None if args.halo is None else tuple(args.halo)

    image_folder_annotator(
        args.input_folder, args.output_folder, args.pattern,
        embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=tile_shape, halo=halo, precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
    )
def image_series_annotator( images: Union[List[Union[str, os.PathLike]], List[numpy.ndarray]], output_folder: str, model_type: str = 'vit_l', embedding_path: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, is_volumetric: bool = False, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True, skip_segmented: bool = True) -> Optional[napari.viewer.Viewer]:
def image_series_annotator(
    images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_folder: str,
    model_type: str = util._DEFAULT_MODEL,
    embedding_path: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    is_volumetric: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
    skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for a series of images (supported for both 2d and 3d images).

    The segmentations are saved as tif files in `output_folder`: as `seg_<index>.tif` if the
    images are passed as arrays, otherwise with the input file name and a `.tif` extension.

    Args:
        images: List of the file paths or list of (set of) slices for the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        embedding_path: Folder where the embeddings are saved, one embedding file per image.
            If `None`, the embeddings are not precomputed and not saved.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        is_volumetric: Whether to use the 3d annotator.
        device: The device to run the model on. By default the most performant available device is selected.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
        skip_segmented: Whether to skip images that were already segmented, i.e. images
            for which a segmentation file already exists in `output_folder`.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.

    Raises:
        ValueError: If `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("No images were passed for annotation.")

    end_msg = "You have annotated the last image. Do you wish to close napari?"
    os.makedirs(output_folder, exist_ok=True)

    # Precompute embeddings and amg state (if corresponding options set).
    predictor, decoder, embedding_paths = _precompute(
        images, model_type,
        embedding_path, tile_shape, halo, precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device,
        ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
    )

    have_inputs_as_arrays = isinstance(images[0], np.ndarray)

    def _get_save_path(image_path, current_idx):
        # Path where the segmentation of image `current_idx` is (or will be) stored.
        if have_inputs_as_arrays:
            fname = f"seg_{current_idx:05}.tif"
        else:
            fname = os.path.basename(image_path)
            fname = os.path.splitext(fname)[0] + ".tif"
        return os.path.join(output_folder, fname)

    def _find_next_image_id(start_idx):
        # First index >= start_idx that still needs annotation; len(images) if none is left.
        # The bound is checked before indexing, so this never reads past the end of `images`.
        idx = start_idx
        if skip_segmented:
            while idx < len(images) and os.path.exists(_get_save_path(images[idx], idx)):
                idx += 1
        return idx

    def _load_image(idx):
        # Inputs are either arrays already, or file paths that have to be read.
        return images[idx] if have_inputs_as_arrays else imageio.imread(images[idx])

    def _save_segmentation(image_path, current_idx, segmentation):
        save_path = _get_save_path(image_path, current_idx)
        imageio.imwrite(save_path, segmentation, compression="zlib")

    # Determine the first image to annotate; this also runs when skip_segmented is False.
    next_image_id = _find_next_image_id(0)
    if next_image_id == len(images):
        print(end_msg)
        return
    if skip_segmented:
        print("The first image to annotate is image number", next_image_id)

    image = _load_image(next_image_id)
    image_embedding_path = embedding_paths[next_image_id]

    # Initialize the viewer and annotator for this image.
    state = AnnotatorState()
    viewer, annotator = _initialize_annotator(
        viewer, image, image_embedding_path,
        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
        precompute_amg_state, checkpoint_path, device, embedding_path,
    )

    def _show_end_dialog():
        # Inform the user via dialog; the viewer is closed unless the dialog result is truthy.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images are done already (napari was kept open): there is nothing to save or load.
        if next_image_id >= len(images):
            _show_end_dialog()
            return

        segmentation = viewer.layers["committed_objects"].data
        if segmentation.sum() == 0:
            msg = "Nothing is segmented yet. Do you wish to continue to the next image?"
            abort = widgets._generate_message("info", msg)
            if abort:
                return

        # Save the current segmentation.
        _save_segmentation(images[next_image_id], next_image_id, segmentation)

        # Clear the segmentation already to avoid lagging removal.
        viewer.layers["committed_objects"].data = np.zeros_like(segmentation)

        # Go to the next image, skipping images that are already segmented if requested.
        next_image_id = _find_next_image_id(next_image_id + 1)
        if next_image_id == len(images):
            _show_end_dialog()
            return

        print(
            "Loading next image:", images[next_image_id] if not have_inputs_as_arrays else f"at index {next_image_id}"
        )

        image = _load_image(next_image_id)
        image_embedding_path = embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image

        if state.amg is not None:
            state.amg.clear_state()

        state.initialize_predictor(
            image, model_type=model_type, ndim=3 if is_volumetric else 2,
            save_path=image_embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=predictor, decoder=decoder,
            precompute_amg_state=precompute_amg_state,
            checkpoint_path=checkpoint_path, device=device,
            skip_load=False,
        )
        state.image_shape = _get_input_shape(image, is_volumetric)

        annotator._update_image()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    if return_viewer:
        return viewer
    napari.run()

Run the annotation tool for a series of images (supported for both 2d and 3d images).

Arguments:
  • images: List of the file paths or list of (set of) slices for the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • embedding_path: Path of the folder where the embeddings are saved (one embedding file per image).
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • is_volumetric: Whether to use the 3d annotator.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
  • skip_segmented: Whether to skip images that were already segmented.
Returns:

The napari viewer, only returned if return_viewer=True.

def image_folder_annotator( input_folder: str, output_folder: str, pattern: str = '*', viewer: Optional[napari.viewer.Viewer] = None, return_viewer: bool = False, **kwargs) -> Optional[napari.viewer.Viewer]:
def image_folder_annotator(
    input_folder: str,
    output_folder: str,
    pattern: str = "*",
    viewer: Optional["napari.viewer.Viewer"] = None,
    return_viewer: bool = False,
    **kwargs
) -> Optional["napari.viewer.Viewer"]:
    """Run the annotation tool for all images in a folder that match a glob pattern.

    The matching file paths are sorted and handed to
    `micro_sam.sam_annotator.image_series_annotator`, together with all extra keyword arguments.

    Args:
        input_folder: The folder with the images to be annotated.
        output_folder: The folder where the segmentation results are saved.
        pattern: The glob pattern used to select files inside `input_folder`.
            The default, "*", selects every entry in the folder.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`
            (for example `model_type`, `embedding_path` or `is_volumetric`).

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    search_expression = os.path.join(input_folder, pattern)
    matched_files = glob(search_expression)
    # Sort so the annotation order is deterministic across platforms and runs.
    matched_files.sort()

    return image_series_annotator(
        matched_files,
        output_folder,
        viewer=viewer,
        return_viewer=return_viewer,
        **kwargs,
    )

Run the 2d annotation tool for a series of images in a folder.

Arguments:
  • input_folder: The folder with the images to be annotated.
  • output_folder: The folder where the segmentation results are saved.
  • pattern: The glob pattern for loading files from input_folder. By default all files will be loaded.
  • viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • kwargs: The keyword arguments for micro_sam.sam_annotator.image_series_annotator.
Returns:

The napari viewer, only returned if return_viewer=True.

class ImageSeriesAnnotator(micro_sam.sam_annotator._widgets._WidgetBase):
class ImageSeriesAnnotator(widgets._WidgetBase):
    """Napari widget that starts the image series annotator for a folder of images.

    The general options are the input folder, the output folder and the model type.
    A collapsible "Advanced Settings" section holds the file pattern, volumetric mode,
    device, embedding save path, custom weights, tiling and halo.

    Clicking "Annotate Images" validates the inputs and then calls
    `image_folder_annotator` with the viewer passed to the constructor.
    """

    def __init__(self, viewer: napari.Viewer, parent=None):
        super().__init__(parent=parent)
        self._viewer = viewer

        # Create the UI: the general options.
        self._create_options()

        # Add the settings (collapsible).
        self.layout().addWidget(self._create_settings())

        # Add the run button that starts the annotation.
        self.run_button = QtWidgets.QPushButton("Annotate Images")
        # `clicked` carries a `checked` bool. Connecting `self.__call__` directly would
        # bind that value to `skip_validate`. The lambda drops the signal argument, so a
        # button click always runs the input validation.
        self.run_button.clicked.connect(lambda: self.__call__())
        self.layout().addWidget(self.run_button)

    def _create_options(self):
        """Add the general options: input folder, output folder and model type."""
        self.folder = None
        _, layout = self._add_path_param(
            "folder", self.folder, "directory",
            title="Input Folder", placeholder="Folder with images ...",
            tooltip=get_tooltip("image_series_annotator", "folder")
        )
        self.layout().addLayout(layout)

        self.output_folder = None
        _, layout = self._add_path_param(
            "output_folder", self.output_folder, "directory",
            title="Output Folder", placeholder="Folder to save the results ...",
            tooltip=get_tooltip("image_series_annotator", "output_folder")
        )
        self.layout().addLayout(layout)

        self.model_type = util._DEFAULT_MODEL
        # Model names ending in "decoder" are not offered as a choice here.
        model_options = [
            model for model in util.models().urls.keys() if not model.endswith("decoder")
        ]
        _, layout = self._add_choice_param(
            "model_type", self.model_type, model_options, title="Model:",
            tooltip=get_tooltip("embedding", "model")
        )
        self.layout().addLayout(layout)

    def _create_settings(self):
        """Build the advanced settings and return them wrapped in a collapsible widget."""
        setting_values = QtWidgets.QWidget()
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        self.pattern = "*"
        _, layout = self._add_string_param(
            "pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
        )
        setting_values.layout().addLayout(layout)

        self.is_volumetric = False
        setting_values.layout().addWidget(self._add_boolean_param(
            "is_volumetric", self.is_volumetric, tooltip=get_tooltip("image_series_annotator", "is_volumetric")
        ))

        self.device = "auto"
        device_options = ["auto"] + util._available_devices()
        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        self.embeddings_save_path = None
        _, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        self.custom_weights = None
        _, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
        return settings

    def _validate_inputs(self):
        """Check the input folder and the output folder.

        The input folder must be set and contain at least one file matching
        `self.pattern`. The output folder must be set.

        Returns:
            The result of `widgets._generate_message("error", ...)` if an input is
            invalid, otherwise False. `__call__` aborts when this value is truthy.
        """
        missing_data = self.folder is None or len(glob(os.path.join(self.folder, self.pattern))) == 0
        missing_output = self.output_folder is None
        if missing_data or missing_output:
            msg = ""
            if missing_data:
                msg += "The input folder is missing or empty. "
            if missing_output:
                msg += "The output folder is missing."
            return widgets._generate_message("error", msg)
        return False

    def __call__(self, skip_validate=False):
        """Start the annotator for the selected folder with the current widget settings.

        Args:
            skip_validate: Whether to skip the input validation.
        """
        if not skip_validate and self._validate_inputs():
            return
        tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)

        image_folder_annotator(
            self.folder, self.output_folder, self.pattern,
            model_type=self.model_type,
            embedding_path=self.embeddings_save_path,
            tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,
            device=self.device, is_volumetric=self.is_volumetric,
            viewer=self._viewer, return_viewer=True,
        )

QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())

ImageSeriesAnnotator(viewer: napari.viewer.Viewer, parent=None)
def __init__(self, viewer: napari.Viewer, parent=None):
    """Build the widget UI: general options, collapsible advanced settings and the run button.

    Args:
        viewer: The napari viewer in which the annotator is started.
        parent: Optional parent Qt widget.
    """
    super().__init__(parent=parent)
    self._viewer = viewer

    # Create the UI: the general options.
    self._create_options()

    # Add the settings (collapsible).
    self.layout().addWidget(self._create_settings())

    # Add the run button that starts the annotation.
    self.run_button = QtWidgets.QPushButton("Annotate Images")
    # `clicked` carries a `checked` bool. Connecting `self.__call__` directly would
    # bind that value to `skip_validate`. The lambda drops the signal argument, so a
    # button click always runs the input validation.
    self.run_button.clicked.connect(lambda: self.__call__())
    self.layout().addWidget(self.run_button)
run_button
Inherited Members
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
changeEvent
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuEvent
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
effectiveWinId
ensurePolished
enterEvent
event
find
focusInEvent
focusNextChild
focusNextPrevChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyPressEvent
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumSizeHint
minimumWidth
mouseDoubleClickEvent
mouseGrabber
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
paintEvent
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
resizeEvent
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeHint
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
wheelEvent
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
eventFilter
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM