micro_sam.sam_annotator.util

  1import os
  2import pickle
  3import warnings
  4import argparse
  5from glob import glob
  6from pathlib import Path
  7from typing import List, Optional, Tuple
  8
  9import h5py
 10import napari
 11import numpy as np
 12from skimage import draw
 13from scipy.ndimage import shift
 14
 15from .. import prompt_based_segmentation, util
 16from .. import _model_settings as model_settings
 17from ..multi_dimensional_segmentation import _validate_projection
 18
 19# Green and Red
 20LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
 21"""@private"""
 22
 23
 24#
 25# Misc helper functions
 26#
 27
 28
 29def toggle_label(prompts):
 30    """@private"""
 31    # get the currently selected label
 32    current_properties = prompts.current_properties
 33    current_label = current_properties["label"][0]
 34    new_label = "negative" if current_label == "positive" else "positive"
 35    current_properties["label"] = np.array([new_label])
 36    prompts.current_properties = current_properties
 37    prompts.refresh()
 38    prompts.refresh_colors()
 39
 40
 41def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):
 42
 43    available_models = list(util.get_model_names())
 44    available_models = ", ".join(available_models)
 45
 46    parser = argparse.ArgumentParser(description=description)
 47
 48    parser.add_argument(
 49        "-i", "--input", required=True,
 50        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
 51        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
 52    )
 53    parser.add_argument(
 54        "-k", "--key",
 55        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
 56        "for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
 57    )
 58    parser.add_argument(
 59        "-e", "--embedding_path",
 60        help="The filepath for saving/loading the pre-computed image embeddings. "
 61        "It is recommended to pass this argument and store the embeddings if you want to open the annotator "
 62        "multiple times for this image. Otherwise the embeddings will be recomputed every time."
 63    )
 64
 65    if with_segmentation_result:
 66        parser.add_argument(
 67            "-s", "--segmentation_result",
 68            help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
 69            "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
 70            "have saved intermediate results from the annotator and want to continue with your annotations. "
 71            "Supports the same file formats as 'input'."
 72        )
 73        parser.add_argument(
 74            "-sk", "--segmentation_key",
 75            help="The key for opening the segmentation data. Same rules as for 'key' apply."
 76        )
 77
 78    parser.add_argument(
 79        "-m", "--model_type", default=util._DEFAULT_MODEL,
 80        help=f"The segment anything model that will be used, one of {available_models}."
 81    )
 82    parser.add_argument(
 83        "-c", "--checkpoint", default=None,
 84        help="Checkpoint from which the SAM model will be loaded."
 85    )
 86    parser.add_argument(
 87        "--decoder_path", default=None,
 88        help="Optional checkpoint path to decoder-only weights to enable decoder-based instance segmentation."
 89    )
 90    parser.add_argument(
 91        "-d", "--device", default=None,
 92        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
 93        "By default the most performant available device will be selected."
 94    )
 95    parser.add_argument(
 96        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
 97    )
 98    parser.add_argument(
 99        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
100    )
101
102    if with_instance_segmentation:
103        parser.add_argument(
104            "--precompute_amg_state", action="store_true",
105            help="Whether to precompute the state for automatic instance segmentation. "
106            "This will lead to a longer start-up time, but the automatic instance segmentation can "
107            "be run directly once the tool has started."
108        )
109        parser.add_argument(
110            "--prefer_decoder", action="store_false",
111            help="Whether to use decoder based instance segmentation if the model "
112            "being used has an additional decoder for that purpose."
113        )
114
115    return parser
116
117
def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
    """@private

    Remove all point prompts and, if the layer exists, all shape prompts from the viewer.
    Optionally also reset the 'current_object' layer to an all-zero array.

    Args:
        viewer: The napari viewer holding the annotation layers.
        clear_segmentations: Whether to also reset the 'current_object' layer.
    """
    layers = viewer.layers

    point_layer = layers["point_prompts"]
    point_layer.data = []
    point_layer.refresh()

    if "prompts" in layers:
        shape_layer = layers["prompts"]
        # Before napari 0.5 this was done via `shape_layer.data = []`.
        # Now all shapes are selected first and then removed.
        shape_layer.selected_data = set(range(len(shape_layer.data)))
        shape_layer.remove_selected()
        shape_layer.refresh()

    if clear_segmentations:
        object_layer = layers["current_object"]
        object_layer.data = np.zeros(object_layer.data.shape, dtype="uint32")
        object_layer.refresh()
133
134
def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None:
    """@private

    Remove the point and shape prompts that belong to slice (or frame) `i`.
    Optionally also clear slice `i` of the 'current_object' layer.

    Args:
        viewer: The napari viewer holding the annotation layers.
        i: The index of the slice or frame to clear.
        clear_segmentations: Whether to also set slice `i` of 'current_object' to zero.
    """
    point_prompts = viewer.layers["point_prompts"].data
    # Points are assigned to a slice by rounding their first coordinate; this is the same rule that
    # `point_layer_to_prompts` uses, so that every point that is used for slice `i` is also cleared.
    point_prompts = point_prompts[np.round(point_prompts[:, 0]) != i]
    viewer.layers["point_prompts"].data = point_prompts
    viewer.layers["point_prompts"].refresh()
    if "prompts" in viewer.layers:
        prompts = viewer.layers["prompts"].data
        # A shape belongs to slice `i` if all of its vertices lie in that slice.
        prompts = [prompt for prompt in prompts if not (prompt[:, 0] == i).all()]
        viewer.layers["prompts"].data = prompts
        viewer.layers["prompts"].refresh()
    if not clear_segmentations:
        return
    viewer.layers["current_object"].data[i] = 0
    viewer.layers["current_object"].refresh()
150
151
152#
153# Helper functions to extract prompts from napari layers.
154#
155
156
def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
        If `with_stop_annotation` is set and the selection consists of a single
        negative point, None is returned instead of the two arrays.
    """
    coords = layer.data
    label_names = layer.properties["label"]
    assert len(coords) == len(label_names)

    if i is not None:
        # Data with a leading slice / frame axis: keep the points whose rounded first coordinate is `i`.
        assert coords.shape[1] == 3, f"{coords.shape}"
        in_slice = np.round(coords[:, 0]) == i
        selected_coords = coords[in_slice][:, 1:]
        selected_names = label_names[in_slice]
    else:
        assert coords.shape[1] == 2, f"{coords.shape}"
        selected_coords, selected_names = coords, label_names
    assert len(selected_coords) == len(selected_names)

    if track_id is not None:
        assert i is not None
        # Restrict the selection further to the points that belong to the requested track.
        slice_track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])[in_slice]
        in_track = slice_track_ids == track_id
        selected_coords = selected_coords[in_track]
        selected_names = selected_names[in_track]
    assert len(selected_coords) == len(selected_names)

    # Encode the labels: 1 for "positive", 0 for everything else.
    encoded_labels = np.array([1 if name == "positive" else 0 for name in selected_names])

    # A single point with a negative label is interpreted as 'stop' signal, which is signalled by None.
    is_stop = len(selected_coords) == 1 and encoded_labels[0] == 0
    if with_stop_annotation and is_stop:
        return None

    return selected_coords, encoded_labels
202
203
def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    A bounding box is extracted for 'rectangle', 'ellipse' and 'polygon' shapes.
    For 'ellipse' and 'polygon' shapes the corresponding mask is extracted as well,
    for 'rectangle' shapes the mask entry is None. Other shape types are skipped with a warning.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts.
        The mask prompts.
    """

    def _rasterize(rows, cols):
        # Binary image-sized mask that is set at the given pixel coordinates.
        canvas = np.zeros(shape, dtype=bool)
        canvas[rows, cols] = 1
        return canvas

    def _bounding_box(vertices):
        # Box layout: [min axis 0, min axis 1, max axis 0, max axis 1].
        return np.array([vertices[:, 0].min(), vertices[:, 1].min(), vertices[:, 0].max(), vertices[:, 1].max()])

    all_vertices, all_types = layer.data, layer.shape_type
    assert len(all_vertices) == len(all_types)
    if len(all_vertices) == 0:
        return [], []

    if i is not None:
        # Only keep the shapes for which all vertices lie in slice / frame `i`
        # (and that belong to the requested track, if a track id is given).
        if track_id is None:
            keep = [j for j, vertices in enumerate(all_vertices) if (vertices[:, 0] == i).all()]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            keep = [
                j for j, (vertices, this_track_id) in enumerate(zip(all_vertices, track_ids))
                if ((vertices[:, 0] == i).all() and this_track_id == track_id)
            ]

        # Drop the leading slice / frame coordinate of the selected shapes.
        all_vertices = [all_vertices[j][:, 1:] for j in keep]
        all_types = [all_types[j] for j in keep]

    boxes, masks = [], []
    for vertices, kind in zip(all_vertices, all_types):

        if kind == "rectangle":
            mask = None

        elif kind == "ellipse":
            center = np.mean(vertices, axis=0)
            radius_r = (vertices[2, 0] - vertices[1, 0]) / 2
            radius_c = (vertices[1, 1] - vertices[0, 1]) / 2
            mask = _rasterize(*draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape))

        elif kind == "polygon":
            mask = _rasterize(*draw.polygon(vertices[:, 0], vertices[:, 1], shape=shape))

        else:
            warnings.warn(f"Shape type {kind} is not supported and will be ignored.")
            continue

        boxes.append(_bounding_box(vertices))
        masks.append(mask)

    return boxes, masks
278
279
def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    state = prompt_layer.properties["state"]

    points = prompt_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Assign points to frames by rounding the first coordinate, which is the same rule that
    # `point_layer_to_prompts` uses. With an exact comparison the state of a point that is
    # used as prompt for frame `i` could be missed.
    mask = np.round(points[:, 0]) == i
    this_state = state[mask]

    # we set the state to 'division' if at least one point in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    return "track"
306
307
def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer.
        box_layer: The napari box layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    points = point_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Assign points to frames by rounding the first coordinate, which is the same rule that
    # `point_layer_to_prompts` uses.
    mask = np.round(points[:, 0]) == i
    this_state = point_states[mask].tolist() if mask.any() else []

    # A box belongs to frame `i` if all of its vertices lie in that frame.
    box_states = box_layer.properties["state"]
    this_state.extend(
        box_state for box, box_state in zip(box_layer.data, box_states)
        if (box[:, 0] == i).all()
    )

    # we set the state to 'division' if at least one prompt in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    return "track"
343
344
345#
346# Helper functions to run (multi-dimensional) segmentation on napari layers.
347#
348
349
def segment_slices_with_prompts(
    predictor, point_prompts, box_prompts, image_embeddings, shape, track_id=None, update_progress=None,
):
    """@private

    Segment all slices (or frames) that contain point or box prompts.

    Args:
        predictor: The predictor passed on to `prompt_segmentation`.
        point_prompts: The layer with the point prompts.
        box_prompts: The layer with the box / shape prompts.
        image_embeddings: The image embeddings passed on to `prompt_segmentation`.
        shape: The shape of the volume or timeseries (must have three entries).
        track_id: Optional track id. If given, only prompts with this track id are used.
        update_progress: Optional callback that is called with 1 for each processed slice.

    Returns:
        The segmentation (zero for all slices without a prompt-based result).
        The indices of the slices that were annotated (without ignored stop annotations).
        Whether the lowest annotated slice is a stop annotation.
        Whether the highest annotated slice is a stop annotation.
    """
    assert len(shape) == 3
    image_shape = shape[1:]
    seg = np.zeros(shape, dtype="uint32")

    # Determine the slices that contain point or box prompts.
    z_values = np.round(point_prompts.data[:, 0])
    if box_prompts.data:
        z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data])
    else:
        z_values_boxes = np.zeros(0, dtype="int")

    if track_id is not None:
        track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
        assert len(track_ids_points) == len(z_values)
        z_values = z_values[track_ids_points == track_id]

        if len(z_values_boxes) > 0:
            track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
            assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
            z_values_boxes = z_values_boxes[track_ids_boxes == track_id]

    slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
    stop_lower, stop_upper = False, False

    if update_progress is None:
        def update_progress(*args):
            pass

    # NOTE: `slices` may be re-bound inside the loop; the loop itself still iterates over the initial array.
    for i in slices:
        points_i = point_layer_to_prompts(point_prompts, i, track_id)

        # do we end the segmentation at the outer slices?
        if points_i is None:

            if i == slices[0]:  # The bottom slice is a stop slice.
                stop_lower = True
                seg[i] = 0
            elif i == slices[-1]:  # The top slice is a stop slice.
                stop_upper = True
                seg[i] = 0
            else:  # We have a stop annotation somewhere in the middle. Ignore this.
                # Remove this slice from the annotated slices, so that it is segmented via
                # projection in the next step.
                slices = np.setdiff1d(slices, i)
                print(f"You have provided a stop annotation (single red point) in slice {i},")
                print("but you have annotated slices above or below it. This stop annotation will")
                print(f"be ignored and the slice {i} will be segmented normally.")

            update_progress(1)
            continue

        boxes, masks = shape_layer_to_prompts(box_prompts, image_shape, i=i, track_id=track_id)
        points, labels = points_i

        seg_i = prompt_segmentation(
            predictor, points, labels, boxes, masks, image_shape, multiple_box_prompts=False,
            image_embeddings=image_embeddings, i=i
        )
        if seg_i is None:
            print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
            print("This will lead to a wrong segmentation across slices or frames.")
            print(f"Please correct the prompts in {i} and rerun the segmentation.")
            # The slice has been processed (and skipped), so the progress still has to advance.
            update_progress(1)
            continue

        seg[i] = seg_i
        update_progress(1)

    return seg, slices, stop_lower, stop_upper
419
420
421# For advanced batching: match prompts to already segmented objects and continue segmentation.
def _match_prompts(previous_segmentation, points, boxes, seg_ids):
    """Placeholder for matching prompts to already segmented objects.

    The matching is not implemented: all arguments are ignored and an empty
    mapping between ids and prompts is returned.
    """
    # Sketch of a possible implementation that was left as a reference:
    # seg_boundaries = find_boundaries(previous_segmentation, mode="inner")
    # indices = distance_transform_edt(seg_boundaries, return_distance=False, return_index=True)
    return {}
428
429
def _batched_interactive_segmentation(predictor, points, labels, boxes, image_embeddings, i, previous_segmentation):
    """Segment one object per positive point and per box and combine them into one label image.

    Each positive point and each box is treated as the prompt for a separate object.
    All negative points are added to the prompt of every object. The objects are written into the
    result in prompt order (positive points first, then boxes) with ids starting at 1, so later
    objects overwrite earlier ones where they overlap.

    Args:
        predictor: The predictor passed on to the prompt based segmentation functions.
        points: The point coordinates, one row per point.
        labels: The point labels (1 for positive points, anything else for negative points).
        boxes: The box prompts.
        image_embeddings: The image embeddings passed on to the segmentation functions.
        i: Index of the slice or frame, or None for 2d data.
        previous_segmentation: The previous segmentation. Only its shape
            (or the shape of slice `i`) is used.

    Returns:
        The segmentation as "uint32" label image.
    """
    prev_seg = previous_segmentation if i is None else previous_segmentation[i]
    seg = np.zeros(prev_seg.shape, dtype="uint32")

    # Split the points into positive points (one object each) and negative points (shared).
    # Slicing with j:j+1 keeps the leading axis, so that the pieces can be concatenated again.
    positive_points, positive_labels = [], []
    negative_points, negative_labels = [], []
    for j in range(len(points)):
        if labels[j] == 1:  # positive point
            positive_points.append(points[j:j+1])
            positive_labels.append(labels[j:j+1])
        else:  # negative points
            negative_points.append(points[j:j+1])
            negative_labels.append(labels[j:j+1])

    have_negatives = len(negative_points) > 0
    if have_negatives:
        # Stack the negative prompts once into arrays. Previously the raw Python lists were
        # passed on for box prompts, whereas the point prompts received concatenated arrays.
        shared_negative_points = np.concatenate(negative_points)
        shared_negative_labels = np.concatenate(negative_labels)

    # Each prompt is a (box, point, label) triple.
    prompts = [(None, point, label) for point, label in zip(positive_points, positive_labels)]
    prompts.extend((box, None, None) for box in boxes)

    # For advanced batching the prompts could be matched to already segmented objects
    # (see `_match_prompts`) in order to continue their segmentation. This is not implemented,
    # it has not been decided yet if this is actually a good idea or not.

    for seg_id, (box, point, label) in enumerate(prompts, 1):
        if have_negatives:
            if point is None:
                point, label = shared_negative_points, shared_negative_labels
            else:
                point = np.concatenate([point, shared_negative_points])
                label = np.concatenate([label, shared_negative_labels])

        if (box is not None) and (point is not None):
            prediction = prompt_based_segmentation.segment_from_box_and_points(
                predictor, box, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()
        elif box is not None:
            prediction = prompt_based_segmentation.segment_from_box(
                predictor, box, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            prediction = prompt_based_segmentation.segment_from_points(
                predictor, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()

        seg[prediction] = seg_id

    return seg
491
492
def prompt_segmentation(
    predictor, points, labels, boxes, masks, shape, multiple_box_prompts,
    image_embeddings=None, i=None, box_extension=0, batched=None, previous_segmentation=None,
):
    """@private

    Run the prompt based segmentation that matches the given combination of prompts.

    Returns None if no prompts are given or if the combination of prompts is not supported.
    """
    assert len(points) == len(labels)
    n_points, n_boxes = len(points), len(boxes)

    # No prompts were given.
    if n_points == 0 and n_boxes == 0:
        return None

    # Batched interactive segmentation.
    if batched:
        assert previous_segmentation is not None
        return _batched_interactive_segmentation(
            predictor, points, labels, boxes, image_embeddings, i, previous_segmentation
        )

    # Keyword arguments shared by all segmentation calls below.
    embedding_kwargs = dict(image_embeddings=image_embeddings, i=i)

    # Only point prompts were given.
    if n_boxes == 0:
        return prompt_based_segmentation.segment_from_points(
            predictor, points, labels, **embedding_kwargs
        ).squeeze()

    # Box and point prompts were given.
    if n_points > 0:
        if n_boxes > 1:
            print("You have provided point prompts and more than one box prompt.")
            print("This setting is currently not supported.")
            print("When providing both points and prompts you can only segment one object at a time.")
            return None

        mask = masks[0]
        if mask is None:
            return prompt_based_segmentation.segment_from_box_and_points(
                predictor, boxes[0], points, labels, **embedding_kwargs
            ).squeeze()
        return prompt_based_segmentation.segment_from_mask(
            predictor, mask, box=boxes[0], points=points, labels=labels, **embedding_kwargs
        ).squeeze()

    # Only box prompts were given.
    seg = np.zeros(shape, dtype="uint32")

    if n_boxes > 1 and not multiple_box_prompts:
        print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
        print("You can only segment one object at a time in 3d.")
        return None

    # Each box yields one object, the ids start at 1.
    for seg_id, (box, mask) in enumerate(zip(boxes, masks), 1):
        if mask is None:
            prediction = prompt_based_segmentation.segment_from_box(
                predictor, box, **embedding_kwargs
            ).squeeze()
        else:
            prediction = prompt_based_segmentation.segment_from_mask(
                predictor, mask, box=box, box_extension=box_extension, **embedding_kwargs
            ).squeeze()
        seg[prediction] = seg_id

    return seg
559
560
def _compute_movement(seg, t0, t1):
    """Compute the offset between the object centers in the frames `t0` and `t1`.

    The center of a frame is the mean coordinate (along the first two axes)
    of all pixels with the value 1.

    Returns:
        The center in `t0` minus the center in `t1`, as float64 array with two entries.
    """
    centers = []
    for frame in (t0, t1):
        coords = np.where(seg[frame] == 1)
        centers.append(np.array([np.mean(coords[0]), np.mean(coords[1])]))

    first_center, second_center = centers
    return (first_center - second_center).astype("float64")
574
575
def _shift_object(mask, motion_model):
    """Return a copy of `mask` that is shifted by the offset `motion_model`.

    The input is not modified; the result has the same shape and dtype as `mask`.
    """
    shifted = np.zeros_like(mask)
    # Spline order 0 without prefilter, the result is written into the pre-allocated array.
    shift(input=mask, shift=motion_model, output=shifted, order=0, prefilter=False)
    return shifted
580
581
def track_from_prompts(
    point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
    stop_upper, threshold, projection, motion_smoothing=0.5, box_extension=0, update_progress=None,
):
    """@private

    Track an object through the frames after the first annotated frame.

    Frames with prompts keep their segmentation from `seg`. For all other frames the mask of the
    previous frame is shifted by the motion model (if one is available) and is used to segment the frame.
    Tracking stops at a division, at the last frame, at the stop annotation (if `stop_upper` is set)
    or if the IOU with the projected mask drops below `threshold` after the last annotated frame.

    Args:
        point_prompts: The layer with the point prompts.
        box_prompts: The layer with the box prompts (currently not used to determine the track state).
        seg: The segmentation, which already contains the results for the frames in `slices`.
            It is updated in-place.
        predictor: The predictor passed on to the segmentation function.
        slices: The sorted indices of the frames with prompts.
        image_embeddings: The image embeddings.
        stop_upper: Whether the last frame in `slices` is a stop annotation.
        threshold: The IOU threshold for stopping the tracking, or None to disable this check.
        projection: The projection mode, validated via `_validate_projection`.
        motion_smoothing: The weight of the previous motion model when it is updated.
        box_extension: Passed on to the segmentation function.
        update_progress: Optional callback that is called with 1 for each frame segmented by projection.

    Returns:
        The updated segmentation.
        Whether the tracking was stopped by a division.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    # shift the segmentation based on the motion model and update the motion model
    def _update_motion_model(seg, t, t0, motion_model):
        if t in (t0, t0 + 1):  # this is the first or second frame, we don't have a motion yet
            pass
        elif t == t0 + 2:  # this the third frame, we initialize the motion model
            current_move = _compute_movement(seg, t - 1, t - 2)
            motion_model = current_move
        else:  # we already have a motion model and update it
            current_move = _compute_movement(seg, t - 1, t - 2)
            alpha = motion_smoothing
            motion_model = alpha * motion_model + (1 - alpha) * current_move

        return motion_model

    has_division = False
    motion_model = None
    verbose = False

    t0 = int(slices.min())
    n_frames = seg.shape[0]
    t = t0 + 1

    # The end conditions are checked before a frame is processed. Checking them only after the
    # first iteration would process an out-of-range frame if the first annotated frame is the
    # last frame, and would track across the stop annotation if it directly follows the first frame.
    while t < n_frames:

        # stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
        if stop_upper and t == slices[-1]:
            break

        # update the motion model
        motion_model = _update_motion_model(seg, t, t0, motion_model)

        # use the segmentation from prompts if we are in a slice with prompts
        if t in slices:
            seg_prev = None
            seg_t = seg[t]
            # currently using the box layer doesn't work for keeping track of the track state
            # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
            track_state = prompt_layer_to_state(point_prompts, t)

        # otherwise project the mask (under the motion model) and segment the next slice from the mask
        else:
            if verbose:
                print(f"Tracking object in frame {t} with movement {motion_model}")

            seg_prev = seg[t - 1]
            # shift the segmentation according to the motion model
            if motion_model is not None:
                seg_prev = _shift_object(seg_prev, motion_model)

            seg_t = prompt_based_segmentation.segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=t,
                use_mask=use_mask, use_box=use_box, use_points=use_points,
                box_extension=box_extension, use_single_point=use_single_point,
            )
            track_state = "track"

            # are we beyond the last slice with prompt?
            # if no: we continue tracking because we know we need to connect to a future frame
            # if yes: we only continue tracking if overlaps are above the threshold
            if t < slices[-1]:
                seg_prev = None

            update_progress(1)

        if (threshold is not None) and (seg_prev is not None):
            iou = util.compute_iou(seg_prev, seg_t)
            if iou < threshold:
                msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
                print(msg)
                break

        # stop if we have a division
        if track_state == "division":
            has_division = True
            break

        seg[t] = seg_t
        t += 1

    return seg, has_division
676
677
def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, device, tile_shape, halo):
    """Update the controls of the embedding widget so that they match the given settings.

    Args:
        widget: The embedding widget.
        model_type: The name of the model. Its fifth character must be the size code
            ('t', 'b', 'l' or 'h'), its suffix determines the model family.
        save_path: The path for saving the embeddings. Only applied if it is a string.
        checkpoint_path: Optional path to custom weights.
        device: Optional name of the device.
        tile_shape: Optional tile shape, the first two entries are used.
        halo: Optional halo, the first two entries are used.
    """

    # Update the index for model family, eg. 'Natural Images (SAM)', 'Light Microscopy', etc.
    supported_dropdown_maps = {
        "lm": "Light Microscopy",
        "em_organelles": "Electron Microscopy",
        "medical_imaging": "Medical Imaging",
        "histopathology": "Histopathology",
    }

    model_family = "Natural Images (SAM)"  # If no suffix patterns match, stick to 'Natural Images (SAM)' family.
    for k, v in supported_dropdown_maps.items():
        if model_type.endswith(k):
            model_family = v
            break

    # 'findText' signals a missing entry with -1, so index 0 is a valid match that must be applied too.
    index = widget.model_family_dropdown.findText(model_family)
    if index >= 0:
        widget.model_family_dropdown.setCurrentIndex(index)

    # Update the index for model size, eg. 'base', 'tiny', etc.
    size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
    model_size = size_map[model_type[4]]

    index = widget.model_size_dropdown.findText(model_size)
    if index >= 0:
        widget.model_size_dropdown.setCurrentIndex(index)

    if save_path is not None and isinstance(save_path, str):
        widget.embeddings_save_path_param.setText(str(save_path))

    if checkpoint_path is not None:
        widget.custom_weights_param.setText(str(checkpoint_path))

    if device is not None:
        widget.device = device
        index = widget.device_dropdown.findText(device)
        # Keep the current selection if the device is not among the dropdown entries.
        if index >= 0:
            widget.device_dropdown.setCurrentIndex(index)

    if tile_shape is not None:
        widget.tile_x_param.setValue(tile_shape[0])
        widget.tile_y_param.setValue(tile_shape[1])

    if halo is not None:
        widget.halo_x_param.setValue(halo[0])
        widget.halo_y_param.setValue(halo[1])
724
725
726# Read parameters from checkpoint path if it is given instead.
def _sync_autosegment_widget(widget, model_type, checkpoint_path, update_decoder=None):
    """Apply the stored automatic segmentation settings for `model_type` to the widget.

    If `update_decoder` is given, the segmentation mode of the widget is reset with it first.
    The settings for decoder-based segmentation are used if the widget has a decoder, otherwise
    the AMG settings. Parameters without a stored value are left as they are.
    `checkpoint_path` is currently not used.
    """
    if update_decoder is not None:
        widget._reset_segmentation_mode(update_decoder)

    # Select the settings table and the parameter names for the segmentation mode.
    if widget.with_decoder:
        settings_table = model_settings.AIS_SETTINGS
        param_names = ("center_distance_thresh", "boundary_distance_thresh")
    else:
        settings_table = model_settings.AMG_SETTINGS
        param_names = ("pred_iou_thresh", "stability_score_thresh", "min_object_size")

    settings = settings_table.get(model_type, {})
    for name in param_names:
        if name not in settings:
            continue
        getattr(widget, f"{name}_param").setValue(settings[name])
743
744
745# Read parameters from checkpoint path if it is given instead.
def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
    """Apply the stored multi-dimensional segmentation settings for `model_type` to the widget.

    Parameters without a stored value are left as they are.
    `checkpoint_path` is currently not used.
    """
    settings = model_settings.ND_SEGMENT_SETTINGS.get(model_type, {})

    if "projection_mode" in settings:
        projection_mode = settings["projection_mode"]
        widget.projection = projection_mode
        index = widget.projection_dropdown.findText(projection_mode)
        # 'findText' signals a missing entry with -1, so index 0 is a valid match. It has to be
        # applied as well, otherwise the dropdown can show a different mode than `widget.projection`.
        if index >= 0:
            widget.projection_dropdown.setCurrentIndex(index)

    params = ("iou_threshold", "box_extension")
    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])
760
761
def _load_amg_state(embedding_path):
    """Load the cached AMG states that are stored next to the embeddings.

    The states are read from the '*.pkl' files in the sub-folder 'amg_state' of `embedding_path`.
    The sub-folder is created if it does not exist. Each state is stored under the integer
    that follows the last '-' in the file name.

    Returns:
        A dictionary with the entry "cache_folder" and one entry per cached state. If
        `embedding_path` is None or does not exist, it only contains "cache_folder" set to None.
    """
    amg_state = {"cache_folder": None}
    have_embedding_folder = embedding_path is not None and os.path.exists(embedding_path)
    if not have_embedding_folder:
        return amg_state

    cache_folder = os.path.join(embedding_path, "amg_state")
    os.makedirs(cache_folder, exist_ok=True)
    amg_state["cache_folder"] = cache_folder

    # NOTE: unpickling can execute arbitrary code, so the cache folder has to be trusted.
    for pickle_path in glob(os.path.join(cache_folder, "*.pkl")):
        with open(pickle_path, "rb") as handle:
            loaded = pickle.load(handle)
        index = int(Path(pickle_path).stem.rsplit("-", 1)[-1])
        amg_state[index] = loaded

    return amg_state
777
778
def _load_is_state(embedding_path):
    """Load the cached instance segmentation states that are stored next to the embeddings.

    The states are read from the file 'is_state.h5' in `embedding_path`. The file is opened in
    append mode. Each group of the file is stored under the integer that follows the last '-'
    in its name.

    Returns:
        A dictionary with the entry "cache_path" and one entry per cached state. If
        `embedding_path` is None or does not exist, it only contains "cache_path" set to None.
    """
    have_embedding_folder = embedding_path is not None and os.path.exists(embedding_path)
    if not have_embedding_folder:
        return {"cache_path": None}

    cache_path = os.path.join(embedding_path, "is_state.h5")
    is_state = {"cache_path": cache_path}

    dataset_names = ("foreground", "boundary_distances", "center_distances")
    with h5py.File(cache_path, "a") as f:
        for group_name, group in f.items():
            index = int(group_name.split("-")[-1])
            # Read each dataset fully while the file is open.
            is_state[index] = {name: group[name][:] for name in dataset_names}

    return is_state
def point_layer_to_prompts( layer: napari.layers.points.points.Points, i=None, track_id=None, with_stop_annotation=True) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Collect the point prompts for SAM that are stored in a napari point layer.

    Args:
        layer: The point layer holding the prompts. Its "label" property marks
            each point as "positive" or anything else (treated as negative).
        i: Index of the slice or frame to select (required for 3d or timeseries data).
            Points are assigned to the index their first coordinate rounds to.
        track_id: Id of the current track (required for tracking data).
            Can only be used together with `i`.
        with_stop_annotation: Whether a single negative point is interpreted
            as stop annotation or just returned as a normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
        None is returned instead if a stop annotation was found.
    """
    coords, label_names = layer.data, layer.properties["label"]
    assert len(coords) == len(label_names)

    in_slice = None
    if i is not None:
        # Data with a leading slice / frame axis: keep the points of index i and drop that axis.
        assert coords.shape[1] == 3, f"{coords.shape}"
        in_slice = np.round(coords[:, 0]) == i
        this_points, this_labels = coords[in_slice][:, 1:], label_names[in_slice]
    else:
        assert coords.shape[1] == 2, f"{coords.shape}"
        this_points, this_labels = coords, label_names
    assert len(this_points) == len(this_labels)

    if track_id is not None:
        assert i is not None
        all_track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])
        same_track = all_track_ids[in_slice] == track_id
        this_points, this_labels = this_points[same_track], this_labels[same_track]
    assert len(this_points) == len(this_labels)

    binary_labels = np.array([1 if name == "positive" else 0 for name in this_labels])

    # A lone negative point is the 'stop' signal; it is reported as None.
    is_stop = len(this_points) == 1 and binary_labels[0] == 0
    if with_stop_annotation and is_stop:
        return None

    return this_points, binary_labels

Extract point prompts for SAM from a napari point layer.

Arguments:
  • layer: The point layer from which to extract the prompts.
  • i: Index for the data (required for 3d or timeseries data).
  • track_id: Id of the current track (required for tracking data).
  • with_stop_annotation: Whether a single negative point will be interpreted as stop annotation or just returned as normal prompt.
Returns:

The point coordinates for the prompts. The labels (positive or negative / 1 or 0) for the prompts.

def shape_layer_to_prompts( layer: napari.layers.shapes.shapes.Shapes, shape: Tuple[int, int], i=None, track_id=None) -> Tuple[List[numpy.ndarray], List[Optional[numpy.ndarray]]]:
def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Convert the shapes of a napari shape layer into box and mask prompts for SAM.

    A 'rectangle' yields only a bounding box (its mask prompt is None), while
    'ellipse' and 'polygon' shapes yield a bounding box and the mask of the shape.
    Other shape types are skipped with a warning.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
            Only shapes whose vertices all lie at this index are used.
        track_id: Id of the current track (required for tracking data).
            Only taken into account if `i` is given.

    Returns:
        The box prompts, each in the format [min_row, min_col, max_row, max_col].
        The mask prompts.
    """

    def _filled_mask(rows, cols):
        # Boolean image of the given shape that is True at the given pixel coordinates.
        filled = np.zeros(shape, dtype=bool)
        filled[rows, cols] = 1
        return filled

    def _to_prompts(all_vertices, all_types):
        outlines, masks = [], []

        for vertices, type_ in zip(all_vertices, all_types):
            if type_ not in ("rectangle", "ellipse", "polygon"):
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")
                continue

            outlines.append(vertices)
            if type_ == "rectangle":
                masks.append(None)
            elif type_ == "ellipse":
                # The ellipse is given by the corners of its bounding box:
                # the center is their mean and the radii are half the side lengths.
                center = np.mean(vertices, axis=0)
                radius_r = ((vertices[2] - vertices[1]) / 2)[0]
                radius_c = ((vertices[1] - vertices[0]) / 2)[1]
                masks.append(_filled_mask(*draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)))
            else:
                masks.append(_filled_mask(*draw.polygon(vertices[:, 0], vertices[:, 1], shape=shape)))

        # Reduce the vertices of every shape to the extent of its bounding box.
        boxes = []
        for vertices in outlines:
            rows, cols = vertices[:, 0], vertices[:, 1]
            boxes.append(np.array([rows.min(), cols.min(), rows.max(), cols.max()]))
        return boxes, masks

    all_vertices, all_types = layer.data, layer.shape_type
    assert len(all_vertices) == len(all_types)
    if len(all_vertices) == 0:
        return [], []

    if i is not None:
        if track_id is None:
            keep = [j for j, vertices in enumerate(all_vertices) if (vertices[:, 0] == i).all()]
        else:
            track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])
            keep = []
            for j, (vertices, this_track_id) in enumerate(zip(all_vertices, track_ids)):
                if (vertices[:, 0] == i).all() and this_track_id == track_id:
                    keep.append(j)

        # Drop the leading slice / frame coordinate of the selected shapes.
        all_vertices = [all_vertices[j][:, 1:] for j in keep]
        all_types = [all_types[j] for j in keep]

    return _to_prompts(all_vertices, all_types)

Extract prompts for SAM from a napari shape layer.

Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask for 'ellipse' and 'polygon' shapes.

Arguments:
  • layer: The napari shape layer.
  • shape: The image shape.
  • i: Index for the data (required for 3d or timeseries data).
  • track_id: Id of the current track (required for tracking data).
Returns:

The box prompts. The mask prompts.

def prompt_layer_to_state(prompt_layer: napari.layers.points.points.Points, i: int) -> str:
def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer. Its "state" property holds the state of each point.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    state = prompt_layer.properties["state"]

    points = prompt_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Assign points to frames by rounding the time coordinate, in the same way as
    # point_layer_to_prompts. Otherwise a point that is used as a prompt for this
    # frame could be ignored when determining the state.
    mask = np.round(points[:, 0]) == i
    this_state = state[mask]

    # we set the state to 'division' if at least one point in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    return "track"

Get the state of the track from a point layer for a given timeframe.

Only relevant for annotator_tracking.

Arguments:
  • prompt_layer: The napari layer.
  • i: Timeframe of the data.
Returns:

The state of this frame (either "division" or "track").

def prompt_layers_to_state( point_layer: napari.layers.points.points.Points, box_layer: napari.layers.shapes.shapes.Shapes, i: int) -> str:
def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer. Its "state" property holds the state of each point.
        box_layer: The napari box layer. Its "state" property holds the state of each box.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    points = point_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Assign points to frames by rounding the time coordinate, in the same way as
    # point_layer_to_prompts. Otherwise a point that is used as a prompt for this
    # frame could be ignored when determining the state.
    mask = np.round(points[:, 0]) == i
    this_state = point_states[mask].tolist() if mask.sum() > 0 else []

    # Boxes belong to the frame only if all their vertices lie exactly in it,
    # which is the same criterion that shape_layer_to_prompts applies.
    box_states = box_layer.properties["state"]
    this_state.extend(
        box_state for box, box_state in zip(box_layer.data, box_states)
        if (box[:, 0] == i).all()
    )

    # we set the state to 'division' if at least one prompt in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    return "track"

Get the state of the track from a point layer and shape layer for a given timeframe.

Only relevant for annotator_tracking.

Arguments:
  • point_layer: The napari point layer.
  • box_layer: The napari box layer.
  • i: Timeframe of the data.
Returns:

The state of this frame (either "division" or "track").