micro_sam.sam_annotator.util

  1import os
  2import pickle
  3import warnings
  4import argparse
  5from glob import glob
  6from pathlib import Path
  7from typing import List, Optional, Tuple
  8
  9import h5py
 10import napari
 11import numpy as np
 12from skimage import draw
 13from scipy.ndimage import shift
 14
 15from .. import prompt_based_segmentation, util
 16from .. import _model_settings as model_settings
 17from ..multi_dimensional_segmentation import _validate_projection
 18
 19# Green and Red
 20LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
 21"""@private"""
 22
 23
 24#
 25# Misc helper functions
 26#
 27
 28
 29def toggle_label(prompts):
 30    """@private"""
 31    # get the currently selected label
 32    current_properties = prompts.current_properties
 33    current_label = current_properties["label"][0]
 34    new_label = "negative" if current_label == "positive" else "positive"
 35    current_properties["label"] = np.array([new_label])
 36    prompts.current_properties = current_properties
 37    prompts.refresh()
 38    prompts.refresh_colors()
 39
 40
 41def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):
 42
 43    available_models = list(util.get_model_names())
 44    available_models = ", ".join(available_models)
 45
 46    parser = argparse.ArgumentParser(description=description)
 47
 48    parser.add_argument(
 49        "-i", "--input", required=True,
 50        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
 51        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
 52    )
 53    parser.add_argument(
 54        "-k", "--key",
 55        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
 56        "for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
 57    )
 58
 59    parser.add_argument(
 60        "-e", "--embedding_path",
 61        help="The filepath for saving/loading the pre-computed image embeddings. "
 62        "It is recommended to pass this argument and store the embeddings if you want to open the annotator "
 63        "multiple times for this image. Otherwise the embeddings will be recomputed every time."
 64    )
 65
 66    if with_segmentation_result:
 67        parser.add_argument(
 68            "-s", "--segmentation_result",
 69            help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
 70            "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
 71            "have saved intermediate results from the annotator and want to continue with your annotations. "
 72            "Supports the same file formats as 'input'."
 73        )
 74        parser.add_argument(
 75            "-sk", "--segmentation_key",
 76            help="The key for opening the segmentation data. Same rules as for 'key' apply."
 77        )
 78
 79    parser.add_argument(
 80        "-m", "--model_type", default=util._DEFAULT_MODEL,
 81        help=f"The segment anything model that will be used, one of {available_models}."
 82    )
 83    parser.add_argument(
 84        "-c", "--checkpoint", default=None,
 85        help="Checkpoint from which the SAM model will be loaded loaded."
 86    )
 87    parser.add_argument(
 88        "-d", "--device", default=None,
 89        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
 90        "By default the most performant available device will be selected."
 91    )
 92
 93    parser.add_argument(
 94        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
 95    )
 96    parser.add_argument(
 97        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
 98    )
 99
100    if with_instance_segmentation:
101        parser.add_argument(
102            "--precompute_amg_state", action="store_true",
103            help="Whether to precompute the state for automatic instance segmentation. "
104            "This will lead to a longer start-up time, but the automatic instance segmentation can "
105            "be run directly once the tool has started."
106        )
107        parser.add_argument(
108            "--prefer_decoder", action="store_false",
109            help="Whether to use decoder based instance segmentation if the model "
110            "being used has an additional decoder for that purpose."
111        )
112
113    return parser
114
115
116def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
117    """@private"""
118    viewer.layers["point_prompts"].data = []
119    viewer.layers["point_prompts"].refresh()
120    if "prompts" in viewer.layers:
121        # Select all prompts and then remove them.
122        # This is how it worked before napari 0.5.
123        # viewer.layers["prompts"].data = []
124        viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
125        viewer.layers["prompts"].remove_selected()
126        viewer.layers["prompts"].refresh()
127    if not clear_segmentations:
128        return
129    viewer.layers["current_object"].data = np.zeros(viewer.layers["current_object"].data.shape, dtype="uint32")
130    viewer.layers["current_object"].refresh()
131
132
133def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None:
134    """@private"""
135    point_prompts = viewer.layers["point_prompts"].data
136    point_prompts = point_prompts[point_prompts[:, 0] != i]
137    viewer.layers["point_prompts"].data = point_prompts
138    viewer.layers["point_prompts"].refresh()
139    if "prompts" in viewer.layers:
140        prompts = viewer.layers["prompts"].data
141        prompts = [prompt for prompt in prompts if not (prompt[:, 0] == i).all()]
142        viewer.layers["prompts"].data = prompts
143        viewer.layers["prompts"].refresh()
144    if not clear_segmentations:
145        return
146    viewer.layers["current_object"].data[i] = 0
147    viewer.layers["current_object"].refresh()
148
149
150#
151# Helper functions to extract prompts from napari layers.
152#
153
154
155def point_layer_to_prompts(
156    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
157) -> Optional[Tuple[np.ndarray, np.ndarray]]:
158    """Extract point prompts for SAM from a napari point layer.
159
160    Args:
161        layer: The point layer from which to extract the prompts.
162        i: Index for the data (required for 3d or timeseries data).
163        track_id: Id of the current track (required for tracking data).
164        with_stop_annotation: Whether a single negative point will be interpreted
165            as a stop annotation or just returned as a normal prompt.
166
167    Returns:
168        The point coordinates for the prompts.
169        The labels (positive or negative / 1 or 0) for the prompts.
170    """
171
172    points = layer.data
173    labels = layer.properties["label"]
174    assert len(points) == len(labels)
175
176    if i is None:
177        assert points.shape[1] == 2, f"{points.shape}"
178        this_points, this_labels = points, labels
179    else:
180        assert points.shape[1] == 3, f"{points.shape}"
181        mask = points[:, 0] == i
182        this_points = points[mask][:, 1:]
183        this_labels = labels[mask]
184    assert len(this_points) == len(this_labels)
185
186    if track_id is not None:
187        assert i is not None
188        track_ids = np.array(list(map(int, layer.properties["track_id"])))[mask]
189        track_id_mask = track_ids == track_id
190        this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
191    assert len(this_points) == len(this_labels)
192
193    this_labels = np.array([1 if label == "positive" else 0 for label in this_labels])
194    # a single point with a negative label is interpreted as 'stop' signal
195    # in this case we return None
196    if with_stop_annotation and (len(this_points) == 1 and this_labels[0] == 0):
197        return None
198
199    return this_points, this_labels
200
201
202def shape_layer_to_prompts(
203    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
204) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
205    """Extract prompts for SAM from a napari shape layer.
206
207    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
208    for 'ellipse' and 'polygon' shapes.
209
210    Args:
211        layer: The napari shape layer.
212        shape: The image shape.
213        i: Index for the data (required for 3d or timeseries data).
214        track_id: Id of the current track (required for tracking data).
215
216    Returns:
217        The box prompts.
218        The mask prompts.
219    """
220
221    def _to_prompts(shape_data, shape_types):
222        boxes, masks = [], []
223
224        for data, type_ in zip(shape_data, shape_types):
225
226            if type_ == "rectangle":
227                boxes.append(data)
228                masks.append(None)
229
230            elif type_ == "ellipse":
231                boxes.append(data)
232                center = np.mean(data, axis=0)
233                radius_r = ((data[2] - data[1]) / 2)[0]
234                radius_c = ((data[1] - data[0]) / 2)[1]
235                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
236                mask = np.zeros(shape, dtype=bool)
237                mask[rr, cc] = 1
238                masks.append(mask)
239
240            elif type_ == "polygon":
241                boxes.append(data)
242                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
243                mask = np.zeros(shape, dtype=bool)
244                mask[rr, cc] = 1
245                masks.append(mask)
246
247            else:
248                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")
249
250        # map to correct box format
251        boxes = [
252            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
253        ]
254        return boxes, masks
255
256    shape_data, shape_types = layer.data, layer.shape_type
257    assert len(shape_data) == len(shape_types)
258    if len(shape_data) == 0:
259        return [], []
260
261    if i is not None:
262        if track_id is None:
263            prompt_selection = [j for j, data in enumerate(shape_data) if (data[:, 0] == i).all()]
264        else:
265            track_ids = np.array(list(map(int, layer.properties["track_id"])))
266            prompt_selection = [
267                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
268                if ((data[:, 0] == i).all() and this_track_id == track_id)
269            ]
270
271        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
272        shape_types = [shape_types[j] for j in prompt_selection]
273
274    boxes, masks = _to_prompts(shape_data, shape_types)
275    return boxes, masks
276
277
278def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
279    """Get the state of the track from a point layer for a given timeframe.
280
281    Only relevant for annotator_tracking.
282
283    Args:
284        prompt_layer: The napari layer.
285        i: Timeframe of the data.
286
287    Returns:
288        The state of this frame (either "division" or "track").
289    """
290    state = prompt_layer.properties["state"]
291
292    points = prompt_layer.data
293    assert points.shape[1] == 3, f"{points.shape}"
294    mask = points[:, 0] == i
295    this_points = points[mask][:, 1:]
296    this_state = state[mask]
297    assert len(this_points) == len(this_state)
298
299    # we set the state to 'division' if at least one point in this frame has a division label
300    if any(st == "division" for st in this_state):
301        return "division"
302    else:
303        return "track"
304
305
306def prompt_layers_to_state(
307    point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int
308) -> str:
309    """Get the state of the track from a point layer and shape layer for a given timeframe.
310
311    Only relevant for annotator_tracking.
312
313    Args:
314        point_layer: The napari point layer.
315        box_layer: The napari box layer.
316        i: Timeframe of the data.
317
318    Returns:
319        The state of this frame (either "division" or "track").
320    """
321    state = point_layer.properties["state"]
322
323    points = point_layer.data
324    assert points.shape[1] == 3, f"{points.shape}"
325    mask = points[:, 0] == i
326    if mask.sum() > 0:
327        this_state = state[mask].tolist()
328    else:
329        this_state = []
330
331    box_states = box_layer.properties["state"]
332    this_box_states = [
333        state for box, state in zip(box_layer.data, box_states)
334        if (box[:, 0] == i).all()
335    ]
336    this_state.extend(this_box_states)
337
338    # we set the state to 'division' if at least one point in this frame has a division label
339    if any(st == "division" for st in this_state):
340        return "division"
341    else:
342        return "track"
343
344
345#
346# Helper functions to run (multi-dimensional) segmentation on napari layers.
347#
348
349
350def segment_slices_with_prompts(
351    predictor, point_prompts, box_prompts, image_embeddings, shape, track_id=None, update_progress=None,
352):
353    """@private"""
354    assert len(shape) == 3
355    image_shape = shape[1:]
356    seg = np.zeros(shape, dtype="uint32")
357
358    z_values = point_prompts.data[:, 0]
359    z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\
360        np.zeros(0, dtype="int")
361
362    if track_id is not None:
363        track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
364        assert len(track_ids_points) == len(z_values)
365        z_values = z_values[track_ids_points == track_id]
366
367        if len(z_values_boxes) > 0:
368            track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
369            assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
370            z_values_boxes = z_values_boxes[track_ids_boxes == track_id]
371
372    slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
373    stop_lower, stop_upper = False, False
374
375    if update_progress is None:
376        def update_progress(*args):
377            pass
378
379    for i in slices:
380        points_i = point_layer_to_prompts(point_prompts, i, track_id)
381
382        # do we end the segmentation at the outer slices?
383        if points_i is None:
384
385            if i == slices[0]:  # The bottom slice is a stop slice.
386                stop_lower = True
387                seg[i] = 0
388            elif i == slices[-1]:  # The top slice is a stop slice.
389                stop_upper = True
390                seg[i] = 0
391            else:  # We have a stop annotation somewhere in the middle. Ignore this.
392                # Remove this slice from the annotated slices, so that it is segmented via
393                # projection in the next step.
394                slices = np.setdiff1d(slices, i)
395                print(f"You have provided a stop annotation (single red point) in slice {i},")
396                print("but you have annotated slices above or below it. This stop annotation will")
397                print(f"be ignored and the slice {i} will be segmented normally.")
398
399            update_progress(1)
400            continue
401
402        boxes, masks = shape_layer_to_prompts(box_prompts, image_shape, i=i, track_id=track_id)
403        points, labels = points_i
404
405        seg_i = prompt_segmentation(
406            predictor, points, labels, boxes, masks, image_shape, multiple_box_prompts=False,
407            image_embeddings=image_embeddings, i=i
408        )
409        if seg_i is None:
410            print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
411            print("This will lead to a wrong segmentation across slices or frames.")
412            print(f"Please correct the prompts in {i} and rerun the segmentation.")
413            continue
414
415        seg[i] = seg_i
416        update_progress(1)
417
418    return seg, slices, stop_lower, stop_upper
419
420
421# For advanced batching: match prompts to already segmented objects and continue segmentation.
422def _match_prompts(previous_segmentation, points, boxes, seg_ids):
423    # Create a mapping between ids and prompts.
424    batched_prompts = {}
425    # seg_boundaries = find_boundaries(previous_segmentation, mode="inner")
426    # indices = distance_transform_edt(seg_boundaries, return_distance=False, return_index=True)
427    return batched_prompts
428
429
430def _batched_interactive_segmentation(predictor, points, labels, boxes, image_embeddings, i, previous_segmentation):
431    prev_seg = previous_segmentation if i is None else previous_segmentation[i]
432    seg = np.zeros(prev_seg.shape, dtype="uint32")
433
434    # seg_ids = np.unique(previous_segmentation)
435    # assert seg_ids[0] == 0
436
437    batched_points, batched_labels = [], []
438    negative_points, negative_labels = [], []
439    for j in range(len(points)):
440        if labels[j] == 1:  # positive point
441            batched_points.append(points[j:j+1])
442            batched_labels.append(labels[j:j+1])
443        else:  # negative points
444            negative_points.append(points[j:j+1])
445            negative_labels.append(labels[j:j+1])
446
447    batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
448    batched_prompts.extend([(box, None, None) for box in boxes])
449    batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}
450
451    # For advanced batching: match prompts to already segmented objects and continue segmentation.
452    # (This is left here as a reference for how this can be implemented.
453    #  I have not decided yet if this is actually a good idea or not.)
454    # # If we have no objects: this is the first call for a batched segmentation.
455    # # We treat each positive point or box as a separate object.
456    # if len(seg_ids) == 1:
457    #     # Create a list of all prompts.
458    #     batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
459    #     batched_prompts.extend([(box, None, None) for box in boxes])
460    #     batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}
461
462    # # Otherwise we match the prompts to existing objects.
463    # else:
464    #     batched_prompts = _match_prompts(prev_seg, batched_points, boxes, seg_ids)
465
466    for seg_id, prompt in batched_prompts.items():
467        box, point, label = prompt
468        if len(negative_points) > 0:
469            if point is None:
470                point, label = negative_points, negative_labels
471            else:
472                point = np.concatenate([point] + negative_points)
473                label = np.concatenate([label] + negative_labels)
474
475        if (box is not None) and (point is not None):
476            prediction = prompt_based_segmentation.segment_from_box_and_points(
477                predictor, box, point, label, image_embeddings=image_embeddings, i=i
478            ).squeeze()
479        elif (box is not None) and (point is None):
480            prediction = prompt_based_segmentation.segment_from_box(
481                predictor, box, image_embeddings=image_embeddings, i=i
482            ).squeeze()
483        else:
484            prediction = prompt_based_segmentation.segment_from_points(
485                predictor, point, label, image_embeddings=image_embeddings, i=i
486            ).squeeze()
487
488        seg[prediction] = seg_id
489
490    return seg
491
492
493def prompt_segmentation(
494    predictor, points, labels, boxes, masks, shape, multiple_box_prompts,
495    image_embeddings=None, i=None, box_extension=0, batched=None,
496    previous_segmentation=None,
497):
498    """@private"""
499    assert len(points) == len(labels)
500    have_points = len(points) > 0
501    have_boxes = len(boxes) > 0
502
503    # No prompts were given, return None.
504    if not have_points and not have_boxes:
505        return
506
507    # Batched interactive segmentation.
508    elif batched:
509        assert previous_segmentation is not None
510        seg = _batched_interactive_segmentation(
511            predictor, points, labels, boxes, image_embeddings, i, previous_segmentation
512        )
513
514    # Box and point prompts were given.
515    elif have_points and have_boxes:
516        if len(boxes) > 1:
517            print("You have provided point prompts and more than one box prompt.")
518            print("This setting is currently not supported.")
519            print("When providing both points and prompts you can only segment one object at a time.")
520            return
521        mask = masks[0]
522        if mask is None:
523            seg = prompt_based_segmentation.segment_from_box_and_points(
524                predictor, boxes[0], points, labels, image_embeddings=image_embeddings, i=i
525            ).squeeze()
526        else:
527            seg = prompt_based_segmentation.segment_from_mask(
528                predictor, mask, box=boxes[0], points=points, labels=labels, image_embeddings=image_embeddings, i=i
529            ).squeeze()
530
531    # Only point prompts were given.
532    elif have_points and not have_boxes:
533        seg = prompt_based_segmentation.segment_from_points(
534            predictor, points, labels, image_embeddings=image_embeddings, i=i
535        ).squeeze()
536
537    # Only box prompts were given.
538    elif not have_points and have_boxes:
539        seg = np.zeros(shape, dtype="uint32")
540
541        if len(boxes) > 1 and not multiple_box_prompts:
542            print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
543            print("You can only segment one object at a time in 3d.")
544            return
545
546        # Batch this?
547        for seg_id, (box, mask) in enumerate(zip(boxes, masks), 1):
548            if mask is None:
549                prediction = prompt_based_segmentation.segment_from_box(
550                    predictor, box, image_embeddings=image_embeddings, i=i
551                ).squeeze()
552            else:
553                prediction = prompt_based_segmentation.segment_from_mask(
554                    predictor, mask, box=box, image_embeddings=image_embeddings, i=i,
555                    box_extension=box_extension,
556                ).squeeze()
557            seg[prediction] = seg_id
558
559    return seg
560
561
562def _compute_movement(seg, t0, t1):
563
564    def compute_center(t):
565        # computation with center of mass
566        center = np.where(seg[t] == 1)
567        center = np.array([np.mean(center[0]), np.mean(center[1])])
568        return center
569
570    center0 = compute_center(t0)
571    center1 = compute_center(t1)
572
573    move = center1 - center0
574    return move.astype("float64")
575
576
577def _shift_object(mask, motion_model):
578    mask_shifted = np.zeros_like(mask)
579    shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
580    return mask_shifted
581
582
583def track_from_prompts(
584    point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
585    stop_upper, threshold, projection,
586    motion_smoothing=0.5, box_extension=0, update_progress=None,
587):
588    """@private
589    """
590    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)
591
592    if update_progress is None:
593        def update_progress(*args):
594            pass
595
596    # shift the segmentation based on the motion model and update the motion model
597    def _update_motion_model(seg, t, t0, motion_model):
598        if t in (t0, t0 + 1):  # this is the first or second frame, we don't have a motion model yet
599            pass
600        elif t == t0 + 2:  # this is the third frame, we initialize the motion model
601            current_move = _compute_movement(seg, t - 1, t - 2)
602            motion_model = current_move
603        else:  # we already have a motion model and update it
604            current_move = _compute_movement(seg, t - 1, t - 2)
605            alpha = motion_smoothing
606            motion_model = alpha * motion_model + (1 - alpha) * current_move
607
608        return motion_model
609
610    has_division = False
611    motion_model = None
612    verbose = False
613
614    t0 = int(slices.min())
615    t = t0 + 1
616    while True:
617
618        # update the motion model
619        motion_model = _update_motion_model(seg, t, t0, motion_model)
620
621        # use the segmentation from prompts if we are in a slice with prompts
622        if t in slices:
623            seg_prev = None
624            seg_t = seg[t]
625            # currently using the box layer doesn't work for keeping track of the track state
626            # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
627            track_state = prompt_layer_to_state(point_prompts, t)
628
629        # otherwise project the mask (under the motion model) and segment the next slice from the mask
630        else:
631            if verbose:
632                print(f"Tracking object in frame {t} with movement {motion_model}")
633
634            seg_prev = seg[t - 1]
635            # shift the segmentation according to the motion model
636            if motion_model is not None:
637                seg_prev = _shift_object(seg_prev, motion_model)
638
639            seg_t = prompt_based_segmentation.segment_from_mask(
640                predictor, seg_prev, image_embeddings=image_embeddings, i=t,
641                use_mask=use_mask, use_box=use_box, use_points=use_points,
642                box_extension=box_extension, use_single_point=use_single_point,
643            )
644            track_state = "track"
645
646            # are we beyond the last slice with prompt?
647            # if no: we continue tracking because we know we need to connect to a future frame
648            # if yes: we only continue tracking if overlaps are above the threshold
649            if t < slices[-1]:
650                seg_prev = None
651
652            update_progress(1)
653
654        if (threshold is not None) and (seg_prev is not None):
655            iou = util.compute_iou(seg_prev, seg_t)
656            if iou < threshold:
657                msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
658                print(msg)
659                break
660
661        # stop if we have a division
662        if track_state == "division":
663            has_division = True
664            break
665
666        seg[t] = seg_t
667        t += 1
668
669        # stop tracking if stop_upper is set (i.e. a single negative point was given to indicate the end of the track)
670        if t == slices[-1] and stop_upper:
671            break
672
673        # stop if we are at the last slice
674        if t == seg.shape[0]:
675            break
676
677    return seg, has_division
678
679
680def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, device, tile_shape, halo):
681    widget.model_type = model_type
682    index = widget.model_dropdown.findText(model_type)
683    if index > 0:
684        widget.model_dropdown.setCurrentIndex(index)
685
686    if save_path is not None:
687        widget.embeddings_save_path_param.setText(str(save_path))
688
689    if checkpoint_path is not None:
690        widget.custom_weights_param.setText(str(checkpoint_path))
691
692    if device is not None:
693        widget.device = device
694        index = widget.device_dropdown.findText(device)
695        widget.device_dropdown.setCurrentIndex(index)
696
697    if tile_shape is not None:
698        widget.tile_x_param.setValue(tile_shape[0])
699        widget.tile_y_param.setValue(tile_shape[1])
700
701    if halo is not None:
702        widget.halo_x_param.setValue(halo[0])
703        widget.halo_y_param.setValue(halo[1])
704
705
706# Read parameters from checkpoint path if it is given instead.
707def _sync_autosegment_widget(widget, model_type, checkpoint_path, update_decoder=None):
708    if update_decoder is not None:
709        widget._reset_segmentation_mode(update_decoder)
710
711    if widget.with_decoder:
712        settings = model_settings.AIS_SETTINGS.get(model_type, {})
713        params = ("center_distance_thresh", "boundary_distance_thresh")
714        for param in params:
715            if param in settings:
716                getattr(widget, f"{param}_param").setValue(settings[param])
717    else:
718        settings = model_settings.AMG_SETTINGS.get(model_type, {})
719        params = ("pred_iou_thresh", "stability_score_thresh", "min_object_size")
720        for param in params:
721            if param in settings:
722                getattr(widget, f"{param}_param").setValue(settings[param])
723
724
725# Read parameters from checkpoint path if it is given instead.
726def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
727    settings = model_settings.ND_SEGMENT_SETTINGS.get(model_type, {})
728
729    if "projection_mode" in settings:
730        projection_mode = settings["projection_mode"]
731        widget.projection = projection_mode
732        index = widget.projection_dropdown.findText(projection_mode)
733        if index > 0:
734            widget.projection_dropdown.setCurrentIndex(index)
735
736    params = ("iou_threshold", "box_extension")
737    for param in params:
738        if param in settings:
739            getattr(widget, f"{param}_param").setValue(settings[param])
740
741
742def _load_amg_state(embedding_path):
743    if embedding_path is None or not os.path.exists(embedding_path):
744        return {"cache_folder": None}
745
746    cache_folder = os.path.join(embedding_path, "amg_state")
747    os.makedirs(cache_folder, exist_ok=True)
748    amg_state = {"cache_folder": cache_folder}
749
750    state_paths = glob(os.path.join(cache_folder, "*.pkl"))
751    for path in state_paths:
752        with open(path, "rb") as f:
753            state = pickle.load(f)
754        i = int(Path(path).stem.split("-")[-1])
755        amg_state[i] = state
756    return amg_state
757
758
759def _load_is_state(embedding_path):
760    if embedding_path is None or not os.path.exists(embedding_path):
761        return {"cache_path": None}
762
763    cache_path = os.path.join(embedding_path, "is_state.h5")
764    is_state = {"cache_path": cache_path}
765
766    with h5py.File(cache_path, "a") as f:
767        for name, g in f.items():
768            i = int(name.split("-")[-1])
769            state = {
770                "foreground": g["foreground"][:],
771                "boundary_distances": g["boundary_distances"][:],
772                "center_distances": g["center_distances"][:],
773            }
774            is_state[i] = state
775
776    return is_state
def point_layer_to_prompts( layer: napari.layers.points.points.Points, i=None, track_id=None, with_stop_annotation=True) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:

Extract point prompts for SAM from a napari point layer.

Arguments:
  • layer: The point layer from which to extract the prompts.
  • i: Index for the data (required for 3d or timeseries data).
  • track_id: Id of the current track (required for tracking data).
  • with_stop_annotation: Whether a single negative point will be interpreted as a stop annotation or just returned as a normal prompt.
Returns:
  • The point coordinates for the prompts.
  • The labels (positive or negative / 1 or 0) for the prompts.
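
A minimal usage sketch (the layer construction, coordinates, and label values below are illustrative assumptions; inside the annotators the "point_prompts" layer is created for you):

    import numpy as np
    import napari

    from micro_sam.sam_annotator.util import point_layer_to_prompts

    # A 3d point layer: the first coordinate is the slice index and the
    # "label" property marks each prompt as positive or negative.
    layer = napari.layers.Points(
        np.array([[2, 10.0, 20.0], [2, 15.0, 25.0]]),
        properties={"label": np.array(["positive", "negative"])},
    )

    prompts = point_layer_to_prompts(layer, i=2)
    if prompts is None:
        # A single negative point in a slice is interpreted as a stop annotation.
        print("Stop annotation in slice 2")
    else:
        coords, labels = prompts  # coords: (N, 2) array; labels: 1 = positive, 0 = negative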

def shape_layer_to_prompts( layer: napari.layers.shapes.shapes.Shapes, shape: Tuple[int, int], i=None, track_id=None) -> Tuple[List[numpy.ndarray], List[Optional[numpy.ndarray]]]:

Extract prompts for SAM from a napari shape layer.

Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask for 'ellipse' and 'polygon' shapes.

Arguments:
  • layer: The napari shape layer.
  • shape: The image shape.
  • i: Index for the data (required for 3d or timeseries data).
  • track_id: Id of the current track (required for tracking data).
Returns:
  • The box prompts.
  • The mask prompts.
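
A minimal sketch for a 2d image, assuming an image shape of (512, 512) and a hand-built rectangle (in the annotators the "prompts" shapes layer is filled interactively):

    import numpy as np
    import napari

    from micro_sam.sam_annotator.util import shape_layer_to_prompts

    # One rectangle given by its four corner points in (row, column) order.
    rectangle = np.array([[10.0, 10.0], [10.0, 50.0], [40.0, 50.0], [40.0, 10.0]])
    layer = napari.layers.Shapes([rectangle], shape_type=["rectangle"])

    boxes, masks = shape_layer_to_prompts(layer, shape=(512, 512))
    # boxes[0] is [min_row, min_col, max_row, max_col]; masks[0] is None for rectangles,
    # while ellipses and polygons additionally yield a boolean mask prompt.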

def prompt_layer_to_state(prompt_layer: napari.layers.points.points.Points, i: int) -> str:

Get the state of the track from a point layer for a given timeframe.

Only relevant for annotator_tracking.

Arguments:
  • prompt_layer: The napari layer.
  • i: Timeframe of the data.
Returns:
  • The state of this frame (either "division" or "track").
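
A short sketch of how the division state of a frame could be queried; the point coordinates and "state" values are assumptions for illustration:

    import numpy as np
    import napari

    from micro_sam.sam_annotator.util import prompt_layer_to_state

    # Two prompts in frame 3; one of them marks a division event.
    layer = napari.layers.Points(
        np.array([[3, 12.0, 30.0], [3, 22.0, 44.0]]),
        properties={"state": np.array(["track", "division"])},
    )

    print(prompt_layer_to_state(layer, i=3))  # -> "division"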

def prompt_layers_to_state( point_layer: napari.layers.points.points.Points, box_layer: napari.layers.shapes.shapes.Shapes, i: int) -> str:

Get the state of the track from a point layer and shape layer for a given timeframe.

Only relevant for annotator_tracking.

Arguments:
  • point_layer: The napari point layer.
  • box_layer: The napari box layer.
  • i: Timeframe of the data.
Returns:
  • The state of this frame (either "division" or "track").
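
And a sketch for the combined variant, with a point layer and a box (shapes) layer for the same frame; the layer contents and "state" values are again illustrative assumptions:

    import numpy as np
    import napari

    from micro_sam.sam_annotator.util import prompt_layers_to_state

    point_layer = napari.layers.Points(
        np.array([[5, 12.0, 30.0]]), properties={"state": np.array(["track"])},
    )
    # A rectangle prompt in frame 5, marked as a division event.
    rectangle = np.array([[5, 10.0, 10.0], [5, 10.0, 50.0], [5, 40.0, 50.0], [5, 40.0, 10.0]])
    box_layer = napari.layers.Shapes(
        [rectangle], shape_type=["rectangle"], properties={"state": np.array(["division"])},
    )

    print(prompt_layers_to_state(point_layer, box_layer, i=5))  # -> "division"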