micro_sam.sam_annotator.util

  1import os
  2import pickle
  3import warnings
  4import argparse
  5from glob import glob
  6from pathlib import Path
  7from typing import List, Optional, Tuple
  8
  9import h5py
 10import napari
 11import numpy as np
 12from skimage import draw
 13from scipy.ndimage import shift
 14
 15from .. import prompt_based_segmentation, util
 16from .. import _model_settings as model_settings
 17from ..multi_dimensional_segmentation import _validate_projection
 18
 19# Green and Red
 20LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
 21"""@private"""
 22
 23
 24#
 25# Misc helper functions
 26#
 27
 28
 29def toggle_label(prompts):
 30    """@private"""
 31    # get the currently selected label
 32    current_properties = prompts.current_properties
 33    current_label = current_properties["label"][0]
 34    new_label = "negative" if current_label == "positive" else "positive"
 35    current_properties["label"] = np.array([new_label])
 36    prompts.current_properties = current_properties
 37    prompts.refresh()
 38    prompts.refresh_colors()
 39
 40
 41def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):
 42
 43    available_models = list(util.get_model_names())
 44    available_models = ", ".join(available_models)
 45
 46    parser = argparse.ArgumentParser(description=description)
 47
 48    parser.add_argument(
 49        "-i", "--input", required=True,
 50        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
 51        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
 52    )
 53    parser.add_argument(
 54        "-k", "--key",
 55        help="The key for opening data with elf.io.open_file. This is the internal path for an hdf5 or zarr container, "
 56        "for an image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
 57    )
 58    parser.add_argument(
 59        "-e", "--embedding_path",
 60        help="The filepath for saving/loading the pre-computed image embeddings. "
 61        "It is recommended to pass this argument and store the embeddings if you want to open the annotator "
 62        "multiple times for this image. Otherwise the embeddings will be recomputed every time."
 63    )
 64
 65    if with_segmentation_result:
 66        parser.add_argument(
 67            "-s", "--segmentation_result",
 68            help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
 69            "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
 70            "have saved intermediate results from the annotator and want to continue with your annotations. "
 71            "Supports the same file formats as 'input'."
 72        )
 73        parser.add_argument(
 74            "-sk", "--segmentation_key",
 75            help="The key for opening the segmentation data. Same rules as for 'key' apply."
 76        )
 77
 78    parser.add_argument(
 79        "-m", "--model_type", default=util._DEFAULT_MODEL,
 80        help=f"The segment anything model that will be used, one of {available_models}."
 81    )
 82    parser.add_argument(
 83        "-c", "--checkpoint", default=None,
 84        help="Checkpoint from which the SAM model will be loaded."
 85    )
 86    parser.add_argument(
 87        "-d", "--device", default=None,
 88        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only on Mac). "
 89        "By default the most performant available device will be selected."
 90    )
 91    parser.add_argument(
 92        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
 93    )
 94    parser.add_argument(
 95        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
 96    )
 97
 98    if with_instance_segmentation:
 99        parser.add_argument(
100            "--precompute_amg_state", action="store_true",
101            help="Whether to precompute the state for automatic instance segmentation. "
102            "This will lead to a longer start-up time, but the automatic instance segmentation can "
103            "be run directly once the tool has started."
104        )
105        parser.add_argument(
106            "--prefer_decoder", action="store_false",
107            help="Whether to use decoder based instance segmentation if the model "
108            "being used has an additional decoder for that purpose."
109        )
110
111    return parser
112
113
114def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
115    """@private"""
116    viewer.layers["point_prompts"].data = []
117    viewer.layers["point_prompts"].refresh()
118    if "prompts" in viewer.layers:
119        # Select all prompts and then remove them.
120        # This is how it worked before napari 0.5.
121        # viewer.layers["prompts"].data = []
122        viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
123        viewer.layers["prompts"].remove_selected()
124        viewer.layers["prompts"].refresh()
125    if not clear_segmentations:
126        return
127    viewer.layers["current_object"].data = np.zeros(viewer.layers["current_object"].data.shape, dtype="uint32")
128    viewer.layers["current_object"].refresh()
129
130
131def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None:
132    """@private"""
133    point_prompts = viewer.layers["point_prompts"].data
134    point_prompts = point_prompts[point_prompts[:, 0] != i]
135    viewer.layers["point_prompts"].data = point_prompts
136    viewer.layers["point_prompts"].refresh()
137    if "prompts" in viewer.layers:
138        prompts = viewer.layers["prompts"].data
139        prompts = [prompt for prompt in prompts if not (prompt[:, 0] == i).all()]
140        viewer.layers["prompts"].data = prompts
141        viewer.layers["prompts"].refresh()
142    if not clear_segmentations:
143        return
144    viewer.layers["current_object"].data[i] = 0
145    viewer.layers["current_object"].refresh()
146
147
148#
149# Helper functions to extract prompts from napari layers.
150#
151
152
153def point_layer_to_prompts(
154    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
155) -> Optional[Tuple[np.ndarray, np.ndarray]]:
156    """Extract point prompts for SAM from a napari point layer.
157
158    Args:
159        layer: The point layer from which to extract the prompts.
160        i: Index for the data (required for 3d or timeseries data).
161        track_id: Id of the current track (required for tracking data).
162        with_stop_annotation: Whether a single negative point will be interpreted
163            as stop annotation or just returned as normal prompt.
164
165    Returns:
166        The point coordinates for the prompts.
167        The labels (positive or negative / 1 or 0) for the prompts.
168    """
169
170    points = layer.data
171    labels = layer.properties["label"]
172    assert len(points) == len(labels)
173
174    if i is None:
175        assert points.shape[1] == 2, f"{points.shape}"
176        this_points, this_labels = points, labels
177    else:
178        assert points.shape[1] == 3, f"{points.shape}"
179        mask = np.round(points[:, 0]) == i
180        this_points = points[mask][:, 1:]
181        this_labels = labels[mask]
182    assert len(this_points) == len(this_labels)
183
184    if track_id is not None:
185        assert i is not None
186        track_ids = np.array(list(map(int, layer.properties["track_id"])))[mask]
187        track_id_mask = track_ids == track_id
188        this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
189    assert len(this_points) == len(this_labels)
190
191    this_labels = np.array([1 if label == "positive" else 0 for label in this_labels])
192    # a single point with a negative label is interpreted as 'stop' signal
193    # in this case we return None
194    if with_stop_annotation and (len(this_points) == 1 and this_labels[0] == 0):
195        return None
196
197    return this_points, this_labels
198
199
200def shape_layer_to_prompts(
201    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
202) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
203    """Extract prompts for SAM from a napari shape layer.
204
205    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
206    for 'ellipse' and 'polygon' shapes.
207
208    Args:
209        layer: The napari shape layer.
210        shape: The image shape.
211        i: Index for the data (required for 3d or timeseries data).
212        track_id: Id of the current track (required for tracking data).
213
214    Returns:
215        The box prompts.
216        The mask prompts.
217    """
218
219    def _to_prompts(shape_data, shape_types):
220        boxes, masks = [], []
221
222        for data, type_ in zip(shape_data, shape_types):
223
224            if type_ == "rectangle":
225                boxes.append(data)
226                masks.append(None)
227
228            elif type_ == "ellipse":
229                boxes.append(data)
230                center = np.mean(data, axis=0)
231                radius_r = ((data[2] - data[1]) / 2)[0]
232                radius_c = ((data[1] - data[0]) / 2)[1]
233                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
234                mask = np.zeros(shape, dtype=bool)
235                mask[rr, cc] = 1
236                masks.append(mask)
237
238            elif type_ == "polygon":
239                boxes.append(data)
240                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
241                mask = np.zeros(shape, dtype=bool)
242                mask[rr, cc] = 1
243                masks.append(mask)
244
245            else:
246                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")
247
248        # map to correct box format
249        boxes = [
250            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
251        ]
252        return boxes, masks
253
254    shape_data, shape_types = layer.data, layer.shape_type
255    assert len(shape_data) == len(shape_types)
256    if len(shape_data) == 0:
257        return [], []
258
259    if i is not None:
260        if track_id is None:
261            prompt_selection = [j for j, data in enumerate(shape_data) if (data[:, 0] == i).all()]
262        else:
263            track_ids = np.array(list(map(int, layer.properties["track_id"])))
264            prompt_selection = [
265                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
266                if ((data[:, 0] == i).all() and this_track_id == track_id)
267            ]
268
269        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
270        shape_types = [shape_types[j] for j in prompt_selection]
271
272    boxes, masks = _to_prompts(shape_data, shape_types)
273    return boxes, masks
274
275
276def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
277    """Get the state of the track from a point layer for a given timeframe.
278
279    Only relevant for annotator_tracking.
280
281    Args:
282        prompt_layer: The napari layer.
283        i: Timeframe of the data.
284
285    Returns:
286        The state of this frame (either "division" or "track").
287    """
288    state = prompt_layer.properties["state"]
289
290    points = prompt_layer.data
291    assert points.shape[1] == 3, f"{points.shape}"
292    mask = points[:, 0] == i
293    this_points = points[mask][:, 1:]
294    this_state = state[mask]
295    assert len(this_points) == len(this_state)
296
297    # we set the state to 'division' if at least one point in this frame has a division label
298    if any(st == "division" for st in this_state):
299        return "division"
300    else:
301        return "track"
302
303
304def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
305    """Get the state of the track from a point layer and shape layer for a given timeframe.
306
307    Only relevant for annotator_tracking.
308
309    Args:
310        point_layer: The napari point layer.
311        box_layer: The napari box layer.
312        i: Timeframe of the data.
313
314    Returns:
315        The state of this frame (either "division" or "track").
316    """
317    state = point_layer.properties["state"]
318
319    points = point_layer.data
320    assert points.shape[1] == 3, f"{points.shape}"
321    mask = points[:, 0] == i
322    if mask.sum() > 0:
323        this_state = state[mask].tolist()
324    else:
325        this_state = []
326
327    box_states = box_layer.properties["state"]
328    this_box_states = [
329        state for box, state in zip(box_layer.data, box_states)
330        if (box[:, 0] == i).all()
331    ]
332    this_state.extend(this_box_states)
333
334    # we set the state to 'division' if at least one point in this frame has a division label
335    if any(st == "division" for st in this_state):
336        return "division"
337    else:
338        return "track"
339
340
341#
342# Helper functions to run (multi-dimensional) segmentation on napari layers.
343#
344
345
346def segment_slices_with_prompts(
347    predictor, point_prompts, box_prompts, image_embeddings, shape, track_id=None, update_progress=None,
348):
349    """@private"""
350    assert len(shape) == 3
351    image_shape = shape[1:]
352    seg = np.zeros(shape, dtype="uint32")
353
354    z_values = np.round(point_prompts.data[:, 0])
355    z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\
356        np.zeros(0, dtype="int")
357
358    if track_id is not None:
359        track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
360        assert len(track_ids_points) == len(z_values)
361        z_values = z_values[track_ids_points == track_id]
362
363        if len(z_values_boxes) > 0:
364            track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
365            assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
366            z_values_boxes = z_values_boxes[track_ids_boxes == track_id]
367
368    slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
369    stop_lower, stop_upper = False, False
370
371    if update_progress is None:
372        def update_progress(*args):
373            pass
374
375    for i in slices:
376        points_i = point_layer_to_prompts(point_prompts, i, track_id)
377
378        # do we end the segmentation at the outer slices?
379        if points_i is None:
380
381            if i == slices[0]:  # The bottom slice is a stop slice.
382                stop_lower = True
383                seg[i] = 0
384            elif i == slices[-1]:  # The top slice is a stop slice.
385                stop_upper = True
386                seg[i] = 0
387            else:  # We have a stop annotation somewhere in the middle. Ignore this.
388                # Remove this slice from the annotated slices, so that it is segmented via
389                # projection in the next step.
390                slices = np.setdiff1d(slices, i)
391                print(f"You have provided a stop annotation (single red point) in slice {i},")
392                print("but you have annotated slices above or below it. This stop annotation will")
393                print(f"be ignored and the slice {i} will be segmented normally.")
394
395            update_progress(1)
396            continue
397
398        boxes, masks = shape_layer_to_prompts(box_prompts, image_shape, i=i, track_id=track_id)
399        points, labels = points_i
400
401        seg_i = prompt_segmentation(
402            predictor, points, labels, boxes, masks, image_shape, multiple_box_prompts=False,
403            image_embeddings=image_embeddings, i=i
404        )
405        if seg_i is None:
406            print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
407            print("This will lead to a wrong segmentation across slices or frames.")
408            print(f"Please correct the prompts in {i} and rerun the segmentation.")
409            continue
410
411        seg[i] = seg_i
412        update_progress(1)
413
414    return seg, slices, stop_lower, stop_upper
415
416
417# For advanced batching: match prompts to already segmented objects and continue segmentation.
418def _match_prompts(previous_segmentation, points, boxes, seg_ids):
419    # Create a mapping between ids and prompts.
420    batched_prompts = {}
421    # seg_boundaries = find_boundaries(previous_segmentation, mode="inner")
422    # indices = distance_transform_edt(seg_boundaries, return_distance=False, return_index=True)
423    return batched_prompts
424
425
426def _batched_interactive_segmentation(predictor, points, labels, boxes, image_embeddings, i, previous_segmentation):
427    prev_seg = previous_segmentation if i is None else previous_segmentation[i]
428    seg = np.zeros(prev_seg.shape, dtype="uint32")
429
430    # seg_ids = np.unique(previous_segmentation)
431    # assert seg_ids[0] == 0
432
433    batched_points, batched_labels = [], []
434    negative_points, negative_labels = [], []
435    for j in range(len(points)):
436        if labels[j] == 1:  # positive point
437            batched_points.append(points[j:j+1])
438            batched_labels.append(labels[j:j+1])
439        else:  # negative points
440            negative_points.append(points[j:j+1])
441            negative_labels.append(labels[j:j+1])
442
443    batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
444    batched_prompts.extend([(box, None, None) for box in boxes])
445    batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}
446
447    # For advanced batching: match prompts to already segmented objects and continue segmentation.
448    # (This is left here as a reference for how this can be implemented.
449    #  I have not decided yet if this is actually a good idea or not.)
450    # # If we have no objects: this is the first call for a batched segmentation.
451    # # We treat each positive point or box as a separate object.
452    # if len(seg_ids) == 1:
453    #     # Create a list of all prompts.
454    #     batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
455    #     batched_prompts.extend([(box, None, None) for box in boxes])
456    #     batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}
457
458    # # Otherwise we match the prompts to existing objects.
459    # else:
460    #     batched_prompts = _match_prompts(prev_seg, batched_points, boxes, seg_ids)
461
462    for seg_id, prompt in batched_prompts.items():
463        box, point, label = prompt
464        if len(negative_points) > 0:
465            if point is None:
466                point, label = negative_points, negative_labels
467            else:
468                point = np.concatenate([point] + negative_points)
469                label = np.concatenate([label] + negative_labels)
470
471        if (box is not None) and (point is not None):
472            prediction = prompt_based_segmentation.segment_from_box_and_points(
473                predictor, box, point, label, image_embeddings=image_embeddings, i=i
474            ).squeeze()
475        elif (box is not None) and (point is None):
476            prediction = prompt_based_segmentation.segment_from_box(
477                predictor, box, image_embeddings=image_embeddings, i=i
478            ).squeeze()
479        else:
480            prediction = prompt_based_segmentation.segment_from_points(
481                predictor, point, label, image_embeddings=image_embeddings, i=i
482            ).squeeze()
483
484        seg[prediction] = seg_id
485
486    return seg
487
488
489def prompt_segmentation(
490    predictor, points, labels, boxes, masks, shape, multiple_box_prompts,
491    image_embeddings=None, i=None, box_extension=0, batched=None, previous_segmentation=None,
492):
493    """@private"""
494    assert len(points) == len(labels)
495    have_points = len(points) > 0
496    have_boxes = len(boxes) > 0
497
498    # No prompts were given, return None.
499    if not have_points and not have_boxes:
500        return
501
502    # Batched interactive segmentation.
503    elif batched:
504        assert previous_segmentation is not None
505        seg = _batched_interactive_segmentation(
506            predictor, points, labels, boxes, image_embeddings, i, previous_segmentation
507        )
508
509    # Box and point prompts were given.
510    elif have_points and have_boxes:
511        if len(boxes) > 1:
512            print("You have provided point prompts and more than one box prompt.")
513            print("This setting is currently not supported.")
514            print("When providing both points and prompts you can only segment one object at a time.")
515            return
516        mask = masks[0]
517        if mask is None:
518            seg = prompt_based_segmentation.segment_from_box_and_points(
519                predictor, boxes[0], points, labels, image_embeddings=image_embeddings, i=i
520            ).squeeze()
521        else:
522            seg = prompt_based_segmentation.segment_from_mask(
523                predictor, mask, box=boxes[0], points=points, labels=labels, image_embeddings=image_embeddings, i=i
524            ).squeeze()
525
526    # Only point prompts were given.
527    elif have_points and not have_boxes:
528        seg = prompt_based_segmentation.segment_from_points(
529            predictor, points, labels, image_embeddings=image_embeddings, i=i
530        ).squeeze()
531
532    # Only box prompts were given.
533    elif not have_points and have_boxes:
534        seg = np.zeros(shape, dtype="uint32")
535
536        if len(boxes) > 1 and not multiple_box_prompts:
537            print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
538            print("You can only segment one object at a time in 3d.")
539            return
540
541        # Batch this?
542        for seg_id, (box, mask) in enumerate(zip(boxes, masks), 1):
543            if mask is None:
544                prediction = prompt_based_segmentation.segment_from_box(
545                    predictor, box, image_embeddings=image_embeddings, i=i
546                ).squeeze()
547            else:
548                prediction = prompt_based_segmentation.segment_from_mask(
549                    predictor, mask, box=box, image_embeddings=image_embeddings, i=i,
550                    box_extension=box_extension,
551                ).squeeze()
552            seg[prediction] = seg_id
553
554    return seg
555
556
557def _compute_movement(seg, t0, t1):
558
559    def compute_center(t):
560        # computation with center of mass
561        center = np.where(seg[t] == 1)
562        center = np.array([np.mean(center[0]), np.mean(center[1])])
563        return center
564
565    center0 = compute_center(t0)
566    center1 = compute_center(t1)
567
568    move = center1 - center0
569    return move.astype("float64")
570
571
572def _shift_object(mask, motion_model):
573    mask_shifted = np.zeros_like(mask)
574    shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
575    return mask_shifted
576
577
578def track_from_prompts(
579    point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
580    stop_upper, threshold, projection, motion_smoothing=0.5, box_extension=0, update_progress=None,
581):
582    """@private
583    """
584    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)
585
586    if update_progress is None:
587        def update_progress(*args):
588            pass
589
590    # shift the segmentation based on the motion model and update the motion model
591    def _update_motion_model(seg, t, t0, motion_model):
592        if t in (t0, t0 + 1):  # this is the first or second frame, we don't have a motion yet
593            pass
594        elif t == t0 + 2:  # this is the third frame, we initialize the motion model
595            current_move = _compute_movement(seg, t - 1, t - 2)
596            motion_model = current_move
597        else:  # we already have a motion model and update it
598            current_move = _compute_movement(seg, t - 1, t - 2)
599            alpha = motion_smoothing
600            motion_model = alpha * motion_model + (1 - alpha) * current_move
601
602        return motion_model
603
604    has_division = False
605    motion_model = None
606    verbose = False
607
608    t0 = int(slices.min())
609    t = t0 + 1
610    while True:
611
612        # update the motion model
613        motion_model = _update_motion_model(seg, t, t0, motion_model)
614
615        # use the segmentation from prompts if we are in a slice with prompts
616        if t in slices:
617            seg_prev = None
618            seg_t = seg[t]
619            # currently using the box layer doesn't work for keeping track of the track state
620            # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
621            track_state = prompt_layer_to_state(point_prompts, t)
622
623        # otherwise project the mask (under the motion model) and segment the next slice from the mask
624        else:
625            if verbose:
626                print(f"Tracking object in frame {t} with movement {motion_model}")
627
628            seg_prev = seg[t - 1]
629            # shift the segmentation according to the motion model
630            if motion_model is not None:
631                seg_prev = _shift_object(seg_prev, motion_model)
632
633            seg_t = prompt_based_segmentation.segment_from_mask(
634                predictor, seg_prev, image_embeddings=image_embeddings, i=t,
635                use_mask=use_mask, use_box=use_box, use_points=use_points,
636                box_extension=box_extension, use_single_point=use_single_point,
637            )
638            track_state = "track"
639
640            # are we beyond the last slice with prompt?
641            # if no: we continue tracking because we know we need to connect to a future frame
642            # if yes: we only continue tracking if overlaps are above the threshold
643            if t < slices[-1]:
644                seg_prev = None
645
646            update_progress(1)
647
648        if (threshold is not None) and (seg_prev is not None):
649            iou = util.compute_iou(seg_prev, seg_t)
650            if iou < threshold:
651                msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
652                print(msg)
653                break
654
655        # stop if we have a division
656        if track_state == "division":
657            has_division = True
658            break
659
660        seg[t] = seg_t
661        t += 1
662
663        # stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
664        if t == slices[-1] and stop_upper:
665            break
666
667        # stop if we are at the last slice
668        if t == seg.shape[0]:
669            break
670
671    return seg, has_division
672
673
674def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, device, tile_shape, halo):
675    widget.model_type = model_type
676    index = widget.model_dropdown.findText(model_type)
677    if index > 0:
678        widget.model_dropdown.setCurrentIndex(index)
679
680    if save_path is not None:
681        widget.embeddings_save_path_param.setText(str(save_path))
682
683    if checkpoint_path is not None:
684        widget.custom_weights_param.setText(str(checkpoint_path))
685
686    if device is not None:
687        widget.device = device
688        index = widget.device_dropdown.findText(device)
689        widget.device_dropdown.setCurrentIndex(index)
690
691    if tile_shape is not None:
692        widget.tile_x_param.setValue(tile_shape[0])
693        widget.tile_y_param.setValue(tile_shape[1])
694
695    if halo is not None:
696        widget.halo_x_param.setValue(halo[0])
697        widget.halo_y_param.setValue(halo[1])
698
699
700# Read parameters from checkpoint path if it is given instead.
701def _sync_autosegment_widget(widget, model_type, checkpoint_path, update_decoder=None):
702    if update_decoder is not None:
703        widget._reset_segmentation_mode(update_decoder)
704
705    if widget.with_decoder:
706        settings = model_settings.AIS_SETTINGS.get(model_type, {})
707        params = ("center_distance_thresh", "boundary_distance_thresh")
708        for param in params:
709            if param in settings:
710                getattr(widget, f"{param}_param").setValue(settings[param])
711    else:
712        settings = model_settings.AMG_SETTINGS.get(model_type, {})
713        params = ("pred_iou_thresh", "stability_score_thresh", "min_object_size")
714        for param in params:
715            if param in settings:
716                getattr(widget, f"{param}_param").setValue(settings[param])
717
718
719# Read parameters from checkpoint path if it is given instead.
720def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
721    settings = model_settings.ND_SEGMENT_SETTINGS.get(model_type, {})
722
723    if "projection_mode" in settings:
724        projection_mode = settings["projection_mode"]
725        widget.projection = projection_mode
726        index = widget.projection_dropdown.findText(projection_mode)
727        if index > 0:
728            widget.projection_dropdown.setCurrentIndex(index)
729
730    params = ("iou_threshold", "box_extension")
731    for param in params:
732        if param in settings:
733            getattr(widget, f"{param}_param").setValue(settings[param])
734
735
736def _load_amg_state(embedding_path):
737    if embedding_path is None or not os.path.exists(embedding_path):
738        return {"cache_folder": None}
739
740    cache_folder = os.path.join(embedding_path, "amg_state")
741    os.makedirs(cache_folder, exist_ok=True)
742    amg_state = {"cache_folder": cache_folder}
743
744    state_paths = glob(os.path.join(cache_folder, "*.pkl"))
745    for path in state_paths:
746        with open(path, "rb") as f:
747            state = pickle.load(f)
748        i = int(Path(path).stem.split("-")[-1])
749        amg_state[i] = state
750    return amg_state
751
752
753def _load_is_state(embedding_path):
754    if embedding_path is None or not os.path.exists(embedding_path):
755        return {"cache_path": None}
756
757    cache_path = os.path.join(embedding_path, "is_state.h5")
758    is_state = {"cache_path": cache_path}
759
760    with h5py.File(cache_path, "a") as f:
761        for name, g in f.items():
762            i = int(name.split("-")[-1])
763            state = {
764                "foreground": g["foreground"][:],
765                "boundary_distances": g["boundary_distances"][:],
766                "center_distances": g["center_distances"][:],
767            }
768            is_state[i] = state
769
770    return is_state
def point_layer_to_prompts( layer: napari.layers.points.points.Points, i=None, track_id=None, with_stop_annotation=True) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:

Extract point prompts for SAM from a napari point layer.

Arguments:
  • layer: The point layer from which to extract the prompts.
  • i: Index for the data (required for 3d or timeseries data).
  • track_id: Id of the current track (required for tracking data).
  • with_stop_annotation: Whether a single negative point will be interpreted as stop annotation or just returned as normal prompt.
Returns:
  • The point coordinates for the prompts.
  • The labels (positive or negative / 1 or 0) for the prompts.
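
A minimal usage sketch (the coordinates, labels and layer name below are hypothetical): build a 2d point layer with the expected "label" property and extract the prompts from it.

import napari
import numpy as np
from micro_sam.sam_annotator.util import point_layer_to_prompts

# Hypothetical prompts for a 2d image: two positive clicks and one negative click.
coords = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]])
layer = napari.layers.Points(
    coords, properties={"label": ["positive", "positive", "negative"]}, name="point_prompts"
)

prompts = point_layer_to_prompts(layer)  # 'i' is omitted because the data is 2d
if prompts is not None:  # None would signal a stop annotation (a single negative point)
    points, labels = prompts
    print(points.shape, labels)  # (3, 2) [1 1 0]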

def shape_layer_to_prompts( layer: napari.layers.shapes.shapes.Shapes, shape: Tuple[int, int], i=None, track_id=None) -> Tuple[List[numpy.ndarray], List[Optional[numpy.ndarray]]]:

Extract prompts for SAM from a napari shape layer.

Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask for 'ellipse' and 'polygon' shapes.

Arguments:
  • layer: The napari shape layer.
  • shape: The image shape.
  • i: Index for the data (required for 3d or timeseries data).
  • track_id: Id of the current track (required for tracking data).
Returns:
  • The box prompts.
  • The mask prompts.
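
A minimal sketch for extracting a box prompt from a rectangle shape (corner coordinates and image shape are hypothetical); ellipses and polygons would additionally return a mask.

import napari
import numpy as np
from micro_sam.sam_annotator.util import shape_layer_to_prompts

image_shape = (256, 256)
# Hypothetical rectangle prompt, given by its four corners in (row, column) order.
rectangle = np.array([[50.0, 60.0], [50.0, 120.0], [110.0, 120.0], [110.0, 60.0]])
layer = napari.layers.Shapes([rectangle], shape_type="rectangle", name="prompts")

boxes, masks = shape_layer_to_prompts(layer, image_shape)
print(boxes[0])  # [ 50.  60. 110. 120.] -> (min_row, min_col, max_row, max_col)
print(masks[0])  # None, since rectangles only provide a box prompt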

def prompt_layer_to_state(prompt_layer: napari.layers.points.points.Points, i: int) -> str:

Get the state of the track from a point layer for a given timeframe.

Only relevant for annotator_tracking.

Arguments:
  • prompt_layer: The napari layer.
  • i: Timeframe of the data.
Returns:

The state of this frame (either "division" or "track").
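
A small sketch for the tracking case (frame indices and coordinates are made up): each point is given as (t, y, x) and carries a "state" property.

import napari
import numpy as np
from micro_sam.sam_annotator.util import prompt_layer_to_state

# Hypothetical tracking prompts: one point in frame 0, two points in frame 1.
coords = np.array([[0.0, 10.0, 10.0], [1.0, 12.0, 11.0], [1.0, 30.0, 35.0]])
layer = napari.layers.Points(
    coords, properties={"state": ["track", "track", "division"]}, name="point_prompts"
)

print(prompt_layer_to_state(layer, i=0))  # track
print(prompt_layer_to_state(layer, i=1))  # division, since one point in frame 1 is marked as division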

def prompt_layers_to_state( point_layer: napari.layers.points.points.Points, box_layer: napari.layers.shapes.shapes.Shapes, i: int) -> str:

Get the state of the track from a point layer and shape layer for a given timeframe.

Only relevant for annotator_tracking.

Arguments:
  • point_layer: The napari point layer.
  • box_layer: The napari box layer.
  • i: Timeframe of the data.
Returns:

The state of this frame (either "division" or "track").
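
A sketch of the combined variant with hypothetical layers: one point prompt in frame 0 and one box prompt marked as a division in frame 1.

import napari
import numpy as np
from micro_sam.sam_annotator.util import prompt_layers_to_state

point_layer = napari.layers.Points(
    np.array([[0.0, 15.0, 15.0]]), properties={"state": ["track"]}, name="point_prompts"
)
box = np.array([[1.0, 20.0, 20.0], [1.0, 20.0, 60.0], [1.0, 50.0, 60.0], [1.0, 50.0, 20.0]])
box_layer = napari.layers.Shapes(
    [box], shape_type="rectangle", properties={"state": ["division"]}, name="prompts"
)

print(prompt_layers_to_state(point_layer, box_layer, i=0))  # track
print(prompt_layers_to_state(point_layer, box_layer, i=1))  # division, from the box prompt in frame 1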