micro_sam.multi_dimensional_segmentation

Multi-dimensional segmentation with segment anything.

  1"""Multi-dimensional segmentation with segment anything.
  2"""
  3
  4import os
  5import multiprocessing as mp
  6from concurrent import futures
  7from typing import Dict, List, Optional, Union, Tuple
  8
  9import networkx as nx
 10import nifty
 11import numpy as np
 12import torch
 13from scipy.ndimage import binary_closing
 14from skimage.measure import label, regionprops
 15from skimage.segmentation import relabel_sequential
 16
 17import elf.segmentation as seg_utils
 18import elf.tracking.tracking_utils as track_utils
 19from elf.tracking.motile_tracking import recolor_segmentation
 20
 21from segment_anything.predictor import SamPredictor
 22
 23try:
 24    from napari.utils import progress as tqdm
 25except ImportError:
 26    from tqdm import tqdm
 27
 28try:
 29    from trackastra.model import Trackastra
 30    from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
 31except ImportError:
 32    Trackastra = None
 33
 34
 35from . import util
 36from .prompt_based_segmentation import segment_from_mask
 37from .instance_segmentation import AMGBase, mask_data_to_segmentation
 38
 39
 40PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")
 41
 42
 43def _validate_projection(projection):
 44    use_single_point = False
 45    if isinstance(projection, str):
 46        if projection == "mask":
 47            use_box, use_mask, use_points = True, True, False
 48        elif projection == "points":
 49            use_box, use_mask, use_points = False, False, True
 50        elif projection == "box":
 51            use_box, use_mask, use_points = True, False, False
 52        elif projection == "points_and_mask":
 53            use_box, use_mask, use_points = False, True, True
 54        elif projection == "single_point":
 55            use_box, use_mask, use_points = False, False, True
 56            use_single_point = True
 57        else:
 58            raise ValueError(
 59                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
 60                f"You have passed the invalid option {projection}."
 61            )
 62    elif isinstance(projection, dict):
 63        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
 64        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
 65    else:
 66        raise ValueError(f"{projection} is not a supported projection method.")
 67    return use_box, use_mask, use_points, use_single_point
 68
 69
 70# Advanced stopping criterions.
 71# In practice these did not make a big difference, so we do not use this at the moment.
 72# We still leave it here for reference.
 73def _advanced_stopping_criteria(
 74    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
 75):
 76    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
 77        iou_list = [
 78            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
 79        ]
 80        return np.mean(iou_list)
 81
 82    if criterion_choice == 1:
 83        # 1. current metric: iou of current segmentation and the previous slice
 84        iou = util.compute_iou(seg_prev, seg_z)
 85        criterion = iou
 86
 87    elif criterion_choice == 2:
 88        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
 89        iou = util.compute_iou(seg_prev, seg_z)
 90        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
 91        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou
 92
 93    elif criterion_choice == 3:
 94        # 3. iou of current segmented slice w.r.t the previous n slices
 95        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))
 96
 97    return criterion
 98
 99
100def segment_mask_in_volume(
101    segmentation: np.ndarray,
102    predictor: SamPredictor,
103    image_embeddings: util.ImageEmbeddings,
104    segmented_slices: np.ndarray,
105    stop_lower: bool,
106    stop_upper: bool,
107    iou_threshold: float,
108    projection: Union[str, dict],
109    update_progress: Optional[callable] = None,
110    box_extension: float = 0.0,
111    verbose: bool = False,
112) -> Tuple[np.ndarray, Tuple[int, int]]:
113    """Segment an object mask in in volumetric data.
114
115    Args:
116        segmentation: The initial segmentation for the object.
117        predictor: The Segment Anything predictor.
118        image_embeddings: The precomputed image embeddings for the volume.
119        segmented_slices: List of slices for which this object has already been segmented.
120        stop_lower: Whether to stop at the lowest segmented slice.
121        stop_upper: Wheter to stop at the topmost segmented slice.
122        iou_threshold: The IOU threshold for continuing segmentation across 3d.
123        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single point'.
124            Pass a dictionary to choose the excact combination of projection modes.
125        update_progress: Callback to update an external progress bar.
126        box_extension: Extension factor for increasing the box size after projection.
127            By default, does not increase the projected box size.
128        verbose: Whether to print details about the segmentation steps. By default, set to 'True'.
129
130    Returns:
131        Array with the volumetric segmentation.
132        Tuple with the first and last segmented slice.
133    """
134    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)
135
136    if update_progress is None:
137        def update_progress(*args):
138            pass
139
140    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
141        z = z_start + increment
142        while True:
143            if verbose:
144                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
145            seg_prev = segmentation[z - increment]
146            seg_z, score, _ = segment_from_mask(
147                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
148                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
149                use_single_point=use_single_point,
150            )
151            if threshold is not None:
152                iou = util.compute_iou(seg_prev, seg_z)
153                if iou < threshold:
154                    if verbose:
155                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
156                        print(msg)
157                    break
158
159            segmentation[z] = seg_z
160            z += increment
161            if stopping_criterion(z, z_stop):
162                if verbose:
163                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
164                break
165            update_progress(1)
166
167        return z - increment
168
169    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
170
171    # segment below the min slice
172    if z0 > 0 and not stop_lower:
173        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
174    else:
175        z_min = z0
176
177    # segment above the max slice
178    if z1 < segmentation.shape[0] - 1 and not stop_upper:
179        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
180    else:
181        z_max = z1
182
183    # segment in between min and max slice
184    if z0 != z1:
185        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
186            slice_diff = z_stop - z_start
187            z_mid = int((z_start + z_stop) // 2)
188
189            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
190                pass
191
192            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
193                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)
194
195            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
196                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)
197
198            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
199                z = z_start + 1
200                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
201                segmentation[z] = segment_from_mask(
202                    predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
203                    use_mask=use_mask, use_box=use_box, use_points=use_points,
204                    box_extension=box_extension
205                )
206                update_progress(1)
207
208            else:  # there is a range of more than 2 slices in between -> segment ranges
209                # segment from bottom
210                segment_range(
211                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
212                )
213                # segment from top
214                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
215                # if the difference between start and stop is even,
216                # then we have a slice in the middle that is the same distance from top bottom
217                # in this case the slice is not segmented in the ranges above, and we segment it
218                # using the combined mask from the adjacent top and bottom slice as prompt
219                if slice_diff % 2 == 0:
220                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
221                    segmentation[z_mid] = segment_from_mask(
222                        predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
223                        use_mask=use_mask, use_box=use_box, use_points=use_points,
224                        box_extension=box_extension
225                    )
226                    update_progress(1)
227
228    return segmentation, (z_min, z_max)
229
230
231def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
232    binarized = slice_segmentation > 0
233    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
234    structuring_element = np.zeros((3, 1, 1))
235    structuring_element[:, 0, 0] = 1
236    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)
237
238    new_segmentation = np.zeros_like(slice_segmentation)
239    n_slices = new_segmentation.shape[0]
240
241    def process_slice(z, offset):
242        seg_z = slice_segmentation[z]
243
244        # Closing does not work for the first and last gap slices
245        if z < gap_closing or z >= (n_slices - gap_closing):
246            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
247            offset = int(seg_z.max()) + 1
248            return seg_z, offset
249
250        # Apply connected components to the closed segmentation.
251        closed_z = label(closed_segmentation[z])
252
253        # Map objects in the closed and initial segmentation.
254        # We take objects from the closed segmentation unless they
255        # have overlap with more than one object from the initial segmentation.
256        # This indicates wrong merging of closeby objects that we want to prevent.
257        matches = nifty.ground_truth.overlap(closed_z, seg_z)
258        matches = {
259            seg_id: matches.overlapArrays(seg_id, sorted=False)[0] for seg_id in range(1, int(closed_z.max() + 1))
260        }
261        matches = {k: v[v != 0] for k, v in matches.items()}
262
263        ids_initial, ids_closed = [], []
264        for seg_id, matched in matches.items():
265            if len(matched) > 1:
266                ids_initial.extend(matched.tolist())
267            else:
268                ids_closed.append(seg_id)
269
270        seg_new = np.zeros_like(seg_z)
271        closed_mask = np.isin(closed_z, ids_closed)
272        seg_new[closed_mask] = closed_z[closed_mask]
273
274        if ids_initial:
275            initial_mask = np.isin(seg_z, ids_initial)
276            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]
277
278        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
279        max_z = seg_new.max()
280        if max_z > 0:
281            offset = int(max_z) + 1
282
283        return seg_new, offset
284
285    # Further optimization: parallelize
286    offset = 1
287    for z in range(n_slices):
288        new_segmentation[z], offset = process_slice(z, offset)
289        pbar_update(1)
290
291    return new_segmentation
292
293
294def _filter_z_extent(segmentation, min_z_extent):
295    props = regionprops(segmentation)
296    filter_ids = []
297    for prop in props:
298        box = prop.bbox
299        z_extent = box[3] - box[0]
300        if z_extent < min_z_extent:
301            filter_ids.append(prop.label)
302    if filter_ids:
303        segmentation[np.isin(segmentation, filter_ids)] = 0
304    return segmentation
305
306
307def merge_instance_segmentation_3d(
308    slice_segmentation: np.ndarray,
309    beta: float = 0.5,
310    with_background: bool = True,
311    gap_closing: Optional[int] = None,
312    min_z_extent: Optional[int] = None,
313    verbose: bool = True,
314    pbar_init: Optional[callable] = None,
315    pbar_update: Optional[callable] = None,
316) -> np.ndarray:
317    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.
318
319    Solves a multicut problem based on the overlap of objects to merge across z.
320
321    Args:
322        slice_segmentation: The stacked segmentation across the slices.
323            We assume that the segmentation is labeled consecutive across z.
324        beta: The bias term for the multicut. Higher values lead to a larger
325            degree of over-segmentation and vice versa. by default, set to '0.5'.
326        with_background: Whether this is a segmentation problem with background.
327            In that case all edges connecting to the background are set to be repulsive.
328            By default, set to 'True'.
329        gap_closing: If given, gaps in the segmentation are closed with a binary closing
330            operation. The value is used to determine the number of iterations for the closing.
331        min_z_extent: Require a minimal extent in z for the segmented objects.
332            This can help to prevent segmentation artifacts.
333        verbose: Verbosity flag. By default, set to 'True'.
334        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
335            Can be used together with pbar_update to handle napari progress bar in other thread.
336            To enables using this function within a threadworker.
337        pbar_update: Callback to update an external progress bar.
338
339    Returns:
340        The merged segmentation.
341    """
342    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
343
344    if gap_closing is not None and gap_closing > 0:
345        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
346        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
347    else:
348        pbar_init(1, "Merge segmentation")
349
350    # Extract the overlap between slices.
351    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
352
353    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
354    overlaps = np.array([edge["score"] for edge in edges])
355
356    n_nodes = int(slice_segmentation.max() + 1)
357    graph = nifty.graph.undirectedGraph(n_nodes)
358    graph.insertEdges(uv_ids)
359
360    costs = seg_utils.multicut.compute_edge_costs(overlaps)
361    # Set background weights to be maximally repulsive.
362    if with_background:
363        bg_edges = (uv_ids == 0).any(axis=1)
364        costs[bg_edges] = -8.0
365
366    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
367
368    segmentation = nifty.tools.take(node_labels, slice_segmentation)
369    if min_z_extent is not None and min_z_extent > 0:
370        segmentation = _filter_z_extent(segmentation, min_z_extent)
371
372    pbar_update(1)
373    pbar_close()
374
375    return segmentation
376
377
378def _segment_slices(
379    data, predictor, segmentor, embedding_path, verbose, tile_shape, halo, with_background=True, batch_size=1, **kwargs
380):
381    assert data.ndim == 3
382
383    min_object_size = kwargs.pop("min_object_size", 0)
384    image_embeddings = util.precompute_image_embeddings(
385        predictor=predictor,
386        input_=data,
387        save_path=embedding_path,
388        ndim=3,
389        tile_shape=tile_shape,
390        halo=halo,
391        verbose=verbose,
392        batch_size=batch_size,
393    )
394
395    offset = 0
396    segmentation = np.zeros(data.shape, dtype="uint32")
397
398    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
399        segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
400        seg = segmentor.generate(**kwargs)
401
402        if isinstance(seg, list) and len(seg) == 0:
403            continue
404        else:
405            if isinstance(seg, list):
406                seg = mask_data_to_segmentation(
407                    seg, with_background=with_background, min_object_size=min_object_size
408                )
409
410            # Set offset for instance per slice.
411            max_z = int(seg.max())
412            if max_z == 0:
413                continue
414            seg[seg != 0] += offset
415            offset = max_z + offset
416
417        segmentation[i] = seg
418
419    return segmentation, image_embeddings
420
421
422def automatic_3d_segmentation(
423    volume: np.ndarray,
424    predictor: SamPredictor,
425    segmentor: AMGBase,
426    embedding_path: Optional[Union[str, os.PathLike]] = None,
427    with_background: bool = True,
428    gap_closing: Optional[int] = None,
429    min_z_extent: Optional[int] = None,
430    tile_shape: Optional[Tuple[int, int]] = None,
431    halo: Optional[Tuple[int, int]] = None,
432    verbose: bool = True,
433    return_embeddings: bool = False,
434    batch_size: int = 1,
435    **kwargs,
436) -> np.ndarray:
437    """Automatically segment objects in a volume.
438
439    First segments slices individually in 2d and then merges them across 3d
440    based on overlap of objects between slices.
441
442    Args:
443        volume: The input volume.
444        predictor: The Segment Anything predictor.
445        segmentor: The instance segmentation class.
446        embedding_path: The path to save pre-computed embeddings.
447        with_background: Whether the segmentation has background. By default, set to 'True'.
448        gap_closing: If given, gaps in the segmentation are closed with a binary closing
449            operation. The value is used to determine the number of iterations for the closing.
450        min_z_extent: Require a minimal extent in z for the segmented objects.
451            This can help to prevent segmentation artifacts.
452        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
453        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
454        verbose: Verbosity flag. By default, set to 'True'.
455        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
456        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
457        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
458
459    Returns:
460        The segmentation.
461    """
462    segmentation, image_embeddings = _segment_slices(
463        data=volume,
464        predictor=predictor,
465        segmentor=segmentor,
466        embedding_path=embedding_path,
467        verbose=verbose,
468        tile_shape=tile_shape,
469        halo=halo,
470        with_background=with_background,
471        batch_size=batch_size,
472        **kwargs
473    )
474    segmentation = merge_instance_segmentation_3d(
475        segmentation,
476        beta=0.5,
477        with_background=with_background,
478        gap_closing=gap_closing,
479        min_z_extent=min_z_extent,
480        verbose=verbose,
481    )
482    if return_embeddings:
483        return segmentation, image_embeddings
484    else:
485        return segmentation
486
487
488def _filter_tracks(tracking_result, min_track_length):
489    props = regionprops(tracking_result)
490    discard_ids = []
491    for prop in props:
492        label_id = prop.label
493        z_start, z_stop = prop.bbox[0], prop.bbox[3]
494        if z_stop - z_start < min_track_length:
495            discard_ids.append(label_id)
496    tracking_result[np.isin(tracking_result, discard_ids)] = 0
497    tracking_result, _, _ = relabel_sequential(tracking_result)
498    return tracking_result
499
500
501def _extract_tracks_and_lineages(segmentations, track_data, parent_graph):
502    # The track data has the following layout: n_tracks x 4
503    # With the following columns:
504    # track_id - id of the track (= result from trackastra)
505    # timepoint
506    # y coordinate
507    # x coordinate
508
509    # Use the last three columns to index the segmentation and get the segmentation id.
510    index = np.round(track_data[:, 1:], 0).astype("int32")
511    index = tuple(index[:, i] for i in range(index.shape[1]))
512    segmentation_ids = segmentations[index]
513
514    # Find the mapping of nodes (= segmented objects) to track-ids.
515    track_ids = track_data[:, 0].astype("int32")
516    assert len(segmentation_ids) == len(track_ids)
517    node_to_track = {k: v for k, v in zip(segmentation_ids, track_ids)}
518
519    # Find the lineages as connected components in the parent graph.
520    # First, we build a proper graph.
521    lineage_graph = nx.Graph()
522    for k, v in parent_graph.items():
523        lineage_graph.add_edge(k, v)
524
525    # Then, find the connected components, and compute the lineage representation expected by micro-sam from it:
526    # E.g. if we have three lineages, the first consisting of three tracks and the second and third of one track each:
527    # [
528    #   {1: [2, 3]},  lineage with a dividing cell
529    #   {4: []}, lineage with just one cell
530    #   {5: []}, lineage with just one cell
531    # ]
532
533    # First, we fill the lineages which have one or more divisions, i.e. trees with more than one node.
534    lineages = []
535    for component in nx.connected_components(lineage_graph):
536        root = next(iter(component))
537        lineage_dict = {}
538
539        def dfs(node, parent):
540            # Avoid revisiting the parent node
541            children = [n for n in lineage_graph[node] if n != parent]
542            lineage_dict[node] = children
543            for child in children:
544                dfs(child, node)
545
546        dfs(root, None)
547        lineages.append(lineage_dict)
548
549    # Then add single node lineages, which are not reflected in the original graph.
550    all_tracks = set(track_ids.tolist())
551    lineage_tracks = []
552    for lineage in lineages:
553        for k, v in lineage.items():
554            lineage_tracks.append(k)
555            lineage_tracks.extend(v)
556    singleton_tracks = list(all_tracks - set(lineage_tracks))
557    lineages.extend([{track: []} for track in singleton_tracks])
558
559    # Make sure node_to_track contains everything.
560    all_seg_ids = np.unique(segmentations)
561    missing_seg_ids = np.setdiff1d(all_seg_ids, list(node_to_track.keys()))
562    node_to_track.update({seg_id: 0 for seg_id in missing_seg_ids})
563    return node_to_track, lineages
564
565
566def _filter_lineages(lineages, tracking_result):
567    track_ids = set(np.unique(tracking_result)) - {0}
568    filtered_lineages = []
569    for lineage in lineages:
570        filtered_lineage = {k: v for k, v in lineage.items() if k in track_ids}
571        if filtered_lineage:
572            filtered_lineages.append(filtered_lineage)
573    return filtered_lineages
574
575
576def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
577    device = "cuda" if torch.cuda.is_available() else "cpu"
578    model = Trackastra.from_pretrained("general_2d", device=device)
579    lineage_graph = model.track(timeseries, segmentation, mode=mode)
580    track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
581    node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
582    tracking_result = recolor_segmentation(segmentation, node_to_track)
583
584    if output_folder is not None:  # Store tracking results in CTC format.
585        graph_to_ctc(lineage_graph, segmentation, outdir=output_folder)
586
587    # TODO
588    # We should check if trackastra supports this already.
589    # Filter out short tracks / lineages.
590    if min_time_extent is not None and min_time_extent > 0:
591        raise NotImplementedError
592
593    # Filter out pruned lineages.
594    # Mmay either be missing due to track filtering or non-consectutive track numbering in trackastra.
595    lineages = _filter_lineages(lineages, tracking_result)
596
597    return tracking_result, lineages
598
599
600def track_across_frames(
601    timeseries: np.ndarray,
602    segmentation: np.ndarray,
603    gap_closing: Optional[int] = None,
604    min_time_extent: Optional[int] = None,
605    verbose: bool = True,
606    pbar_init: Optional[callable] = None,
607    pbar_update: Optional[callable] = None,
608    output_folder: Optional[Union[os.PathLike, str]] = None,
609) -> Tuple[np.ndarray, List[Dict]]:
610    """Track segmented objects over time.
611
612    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
613    for tracking. Please cite it if you use the automated tracking functionality.
614
615    Args:
616        timeseries: The input timeseries of images.
617        segmentation: The segmentation. Expect segmentation results per frame
618            that are relabeled so that segmentation ids don't overlap.
619        gap_closing: If given, gaps in the segmentation are closed with a binary closing
620            operation. The value is used to determine the number of iterations for the closing.
621        min_time_extent: Require a minimal extent in time for the tracked objects.
622        verbose: Verbosity flag. By default, set to 'True'.
623        pbar_init: Function to initialize the progress bar.
624        pbar_update: Function to update the progress bar.
625        output_folder: The folder where the tracking results are stored in CTC format.
626
627    Returns:
628        The tracking result. Each object is colored by its track id.
629        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
630            with each dict encoding a lineage, where keys correspond to parent track ids.
631            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
632    """
633    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)
634
635    if gap_closing is not None and gap_closing > 0:
636        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
637
638    segmentation, lineage = _tracking_impl(
639        timeseries=np.asarray(timeseries),
640        segmentation=segmentation,
641        mode="greedy",
642        min_time_extent=min_time_extent,
643        output_folder=output_folder,
644    )
645    return segmentation, lineage
646
647
648def automatic_tracking_implementation(
649    timeseries: np.ndarray,
650    predictor: SamPredictor,
651    segmentor: AMGBase,
652    embedding_path: Optional[Union[str, os.PathLike]] = None,
653    gap_closing: Optional[int] = None,
654    min_time_extent: Optional[int] = None,
655    tile_shape: Optional[Tuple[int, int]] = None,
656    halo: Optional[Tuple[int, int]] = None,
657    verbose: bool = True,
658    return_embeddings: bool = False,
659    batch_size: int = 1,
660    output_folder: Optional[Union[os.PathLike, str]] = None,
661    **kwargs,
662) -> Tuple[np.ndarray, List[Dict]]:
663    """Automatically track objects in a timesries based on per-frame automatic segmentation.
664
665    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
666    for tracking. Please cite it if you use the automated tracking functionality.
667
668    Args:
669        timeseries: The input timeseries of images.
670        predictor: The SAM model.
671        segmentor: The instance segmentation class.
672        embedding_path: The path to save pre-computed embeddings.
673        gap_closing: If given, gaps in the segmentation are closed with a binary closing
674            operation. The value is used to determine the number of iterations for the closing.
675        min_time_extent: Require a minimal extent in time for the tracked objects.
676        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
677        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
678        verbose: Verbosity flag. By default, set to 'True'.
679        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
680        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
681        output_folder: The folder where the tracking results are stored in CTC format.
682        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
683
684    Returns:
685        The tracking result. Each object is colored by its track id.
686        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
687            with each dict encoding a lineage, where keys correspond to parent track ids.
688            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
689    """
690    if Trackastra is None:
691        raise RuntimeError(
692            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
693        )
694
695    segmentation, image_embeddings = _segment_slices(
696        timeseries, predictor, segmentor, embedding_path, verbose,
697        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
698        **kwargs,
699    )
700
701    segmentation, lineage = track_across_frames(
702        timeseries=timeseries,
703        segmentation=segmentation,
704        gap_closing=gap_closing,
705        min_time_extent=min_time_extent,
706        verbose=verbose,
707        output_folder=output_folder,
708    )
709
710    if return_embeddings:
711        return segmentation, lineage, image_embeddings
712    else:
713        return segmentation, lineage
714
715
716def get_napari_track_data(
717    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
718) -> Tuple[np.ndarray, Dict[int, List]]:
719    """Derive the inputs for the napari tracking layer from a tracking result.
720
721    Args:
722        segmentation: The segmentation, after relabeling with track ids.
723        lineages: The lineage information.
724        n_threads: Number of threads for extracting the track data from the segmentation.
725
726    Returns:
727        The array with the track data expected by napari.
728        The parent dictionary for napari.
729    """
730    if n_threads is None:
731        n_threads = mp.cpu_count()
732
733    def compute_props(t):
734        props = regionprops(segmentation[t])
735        # Create the track data representation for napari, which expects:
736        # track_id, timepoint, y, x
737        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
738        return track_data
739
740    with futures.ThreadPoolExecutor(n_threads) as tp:
741        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
742    track_data = [data for data in track_data if data.size > 0]
743    track_data = np.concatenate(track_data)
744
745    # The graph representation of napari uses the children as keys and the parents as values,
746    # whereas our representation uses parents as keys and children as values.
747    # Hence, we need to translate the representation.
748    parent_graph = {
749        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
750    }
751
752    return track_data, parent_graph
PROJECTION_MODES = ('box', 'mask', 'points', 'points_and_mask', 'single_point')
def segment_mask_in_volume( segmentation: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], segmented_slices: numpy.ndarray, stop_lower: bool, stop_upper: bool, iou_threshold: float, projection: Union[str, dict], update_progress: Optional[<built-in function callable>] = None, box_extension: float = 0.0, verbose: bool = False) -> Tuple[numpy.ndarray, Tuple[int, int]]:
101def segment_mask_in_volume(
102    segmentation: np.ndarray,
103    predictor: SamPredictor,
104    image_embeddings: util.ImageEmbeddings,
105    segmented_slices: np.ndarray,
106    stop_lower: bool,
107    stop_upper: bool,
108    iou_threshold: float,
109    projection: Union[str, dict],
110    update_progress: Optional[callable] = None,
111    box_extension: float = 0.0,
112    verbose: bool = False,
113) -> Tuple[np.ndarray, Tuple[int, int]]:
114    """Segment an object mask in in volumetric data.
115
116    Args:
117        segmentation: The initial segmentation for the object.
118        predictor: The Segment Anything predictor.
119        image_embeddings: The precomputed image embeddings for the volume.
120        segmented_slices: List of slices for which this object has already been segmented.
121        stop_lower: Whether to stop at the lowest segmented slice.
122        stop_upper: Wheter to stop at the topmost segmented slice.
123        iou_threshold: The IOU threshold for continuing segmentation across 3d.
124        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single point'.
125            Pass a dictionary to choose the excact combination of projection modes.
126        update_progress: Callback to update an external progress bar.
127        box_extension: Extension factor for increasing the box size after projection.
128            By default, does not increase the projected box size.
129        verbose: Whether to print details about the segmentation steps. By default, set to 'True'.
130
131    Returns:
132        Array with the volumetric segmentation.
133        Tuple with the first and last segmented slice.
134    """
135    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)
136
137    if update_progress is None:
138        def update_progress(*args):
139            pass
140
141    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
142        z = z_start + increment
143        while True:
144            if verbose:
145                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
146            seg_prev = segmentation[z - increment]
147            seg_z, score, _ = segment_from_mask(
148                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
149                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
150                use_single_point=use_single_point,
151            )
152            if threshold is not None:
153                iou = util.compute_iou(seg_prev, seg_z)
154                if iou < threshold:
155                    if verbose:
156                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
157                        print(msg)
158                    break
159
160            segmentation[z] = seg_z
161            z += increment
162            if stopping_criterion(z, z_stop):
163                if verbose:
164                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
165                break
166            update_progress(1)
167
168        return z - increment
169
170    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
171
172    # segment below the min slice
173    if z0 > 0 and not stop_lower:
174        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
175    else:
176        z_min = z0
177
178    # segment above the max slice
179    if z1 < segmentation.shape[0] - 1 and not stop_upper:
180        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
181    else:
182        z_max = z1
183
184    # segment in between min and max slice
185    if z0 != z1:
186        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
187            slice_diff = z_stop - z_start
188            z_mid = int((z_start + z_stop) // 2)
189
190            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
191                pass
192
193            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
194                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)
195
196            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
197                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)
198
199            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
200                z = z_start + 1
201                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
202                segmentation[z] = segment_from_mask(
203                    predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
204                    use_mask=use_mask, use_box=use_box, use_points=use_points,
205                    box_extension=box_extension
206                )
207                update_progress(1)
208
209            else:  # there is a range of more than 2 slices in between -> segment ranges
210                # segment from bottom
211                segment_range(
212                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
213                )
214                # segment from top
215                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
216                # if the difference between start and stop is even,
217                # then we have a slice in the middle that is the same distance from top bottom
218                # in this case the slice is not segmented in the ranges above, and we segment it
219                # using the combined mask from the adjacent top and bottom slice as prompt
220                if slice_diff % 2 == 0:
221                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
222                    segmentation[z_mid] = segment_from_mask(
223                        predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
224                        use_mask=use_mask, use_box=use_box, use_points=use_points,
225                        box_extension=box_extension
226                    )
227                    update_progress(1)
228
229    return segmentation, (z_min, z_max)

Segment an object mask in in volumetric data.

Arguments:
  • segmentation: The initial segmentation for the object.
  • predictor: The Segment Anything predictor.
  • image_embeddings: The precomputed image embeddings for the volume.
  • segmented_slices: List of slices for which this object has already been segmented.
  • stop_lower: Whether to stop at the lowest segmented slice.
  • stop_upper: Wheter to stop at the topmost segmented slice.
  • iou_threshold: The IOU threshold for continuing segmentation across 3d.
  • projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single point'. Pass a dictionary to choose the excact combination of projection modes.
  • update_progress: Callback to update an external progress bar.
  • box_extension: Extension factor for increasing the box size after projection. By default, does not increase the projected box size.
  • verbose: Whether to print details about the segmentation steps. By default, set to 'True'.
Returns:

Array with the volumetric segmentation. Tuple with the first and last segmented slice.

def merge_instance_segmentation_3d( slice_segmentation: numpy.ndarray, beta: float = 0.5, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> numpy.ndarray:
308def merge_instance_segmentation_3d(
309    slice_segmentation: np.ndarray,
310    beta: float = 0.5,
311    with_background: bool = True,
312    gap_closing: Optional[int] = None,
313    min_z_extent: Optional[int] = None,
314    verbose: bool = True,
315    pbar_init: Optional[callable] = None,
316    pbar_update: Optional[callable] = None,
317) -> np.ndarray:
318    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.
319
320    Solves a multicut problem based on the overlap of objects to merge across z.
321
322    Args:
323        slice_segmentation: The stacked segmentation across the slices.
324            We assume that the segmentation is labeled consecutive across z.
325        beta: The bias term for the multicut. Higher values lead to a larger
326            degree of over-segmentation and vice versa. by default, set to '0.5'.
327        with_background: Whether this is a segmentation problem with background.
328            In that case all edges connecting to the background are set to be repulsive.
329            By default, set to 'True'.
330        gap_closing: If given, gaps in the segmentation are closed with a binary closing
331            operation. The value is used to determine the number of iterations for the closing.
332        min_z_extent: Require a minimal extent in z for the segmented objects.
333            This can help to prevent segmentation artifacts.
334        verbose: Verbosity flag. By default, set to 'True'.
335        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
336            Can be used together with pbar_update to handle napari progress bar in other thread.
337            To enables using this function within a threadworker.
338        pbar_update: Callback to update an external progress bar.
339
340    Returns:
341        The merged segmentation.
342    """
343    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
344
345    if gap_closing is not None and gap_closing > 0:
346        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
347        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
348    else:
349        pbar_init(1, "Merge segmentation")
350
351    # Extract the overlap between slices.
352    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
353
354    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
355    overlaps = np.array([edge["score"] for edge in edges])
356
357    n_nodes = int(slice_segmentation.max() + 1)
358    graph = nifty.graph.undirectedGraph(n_nodes)
359    graph.insertEdges(uv_ids)
360
361    costs = seg_utils.multicut.compute_edge_costs(overlaps)
362    # Set background weights to be maximally repulsive.
363    if with_background:
364        bg_edges = (uv_ids == 0).any(axis=1)
365        costs[bg_edges] = -8.0
366
367    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
368
369    segmentation = nifty.tools.take(node_labels, slice_segmentation)
370    if min_z_extent is not None and min_z_extent > 0:
371        segmentation = _filter_z_extent(segmentation, min_z_extent)
372
373    pbar_update(1)
374    pbar_close()
375
376    return segmentation

Merge stacked 2d instance segmentations into a consistent 3d segmentation.

Solves a multicut problem based on the overlap of objects to merge across z.

Arguments:
  • slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
  • beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa. by default, set to '0.5'.
  • with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive. By default, set to 'True'.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
  • verbose: Verbosity flag. By default, set to 'True'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. To enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The merged segmentation.

def automatic_3d_segmentation( volume: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, batch_size: int = 1, **kwargs) -> numpy.ndarray:
423def automatic_3d_segmentation(
424    volume: np.ndarray,
425    predictor: SamPredictor,
426    segmentor: AMGBase,
427    embedding_path: Optional[Union[str, os.PathLike]] = None,
428    with_background: bool = True,
429    gap_closing: Optional[int] = None,
430    min_z_extent: Optional[int] = None,
431    tile_shape: Optional[Tuple[int, int]] = None,
432    halo: Optional[Tuple[int, int]] = None,
433    verbose: bool = True,
434    return_embeddings: bool = False,
435    batch_size: int = 1,
436    **kwargs,
437) -> np.ndarray:
438    """Automatically segment objects in a volume.
439
440    First segments slices individually in 2d and then merges them across 3d
441    based on overlap of objects between slices.
442
443    Args:
444        volume: The input volume.
445        predictor: The Segment Anything predictor.
446        segmentor: The instance segmentation class.
447        embedding_path: The path to save pre-computed embeddings.
448        with_background: Whether the segmentation has background. By default, set to 'True'.
449        gap_closing: If given, gaps in the segmentation are closed with a binary closing
450            operation. The value is used to determine the number of iterations for the closing.
451        min_z_extent: Require a minimal extent in z for the segmented objects.
452            This can help to prevent segmentation artifacts.
453        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
454        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
455        verbose: Verbosity flag. By default, set to 'True'.
456        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
457        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
458        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
459
460    Returns:
461        The segmentation.
462    """
463    segmentation, image_embeddings = _segment_slices(
464        data=volume,
465        predictor=predictor,
466        segmentor=segmentor,
467        embedding_path=embedding_path,
468        verbose=verbose,
469        tile_shape=tile_shape,
470        halo=halo,
471        with_background=with_background,
472        batch_size=batch_size,
473        **kwargs
474    )
475    segmentation = merge_instance_segmentation_3d(
476        segmentation,
477        beta=0.5,
478        with_background=with_background,
479        gap_closing=gap_closing,
480        min_z_extent=min_z_extent,
481        verbose=verbose,
482    )
483    if return_embeddings:
484        return segmentation, image_embeddings
485    else:
486        return segmentation

Automatically segment objects in a volume.

First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.

Arguments:
  • volume: The input volume.
  • predictor: The Segment Anything predictor.
  • segmentor: The instance segmentation class.
  • embedding_path: The path to save pre-computed embeddings.
  • with_background: Whether the segmentation has background. By default, set to 'True'.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
  • batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
  • kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:

The segmentation.

def track_across_frames( timeseries: numpy.ndarray, segmentation: numpy.ndarray, gap_closing: Optional[int] = None, min_time_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None, output_folder: Union[str, os.PathLike, NoneType] = None) -> Tuple[numpy.ndarray, List[Dict]]:
601def track_across_frames(
602    timeseries: np.ndarray,
603    segmentation: np.ndarray,
604    gap_closing: Optional[int] = None,
605    min_time_extent: Optional[int] = None,
606    verbose: bool = True,
607    pbar_init: Optional[callable] = None,
608    pbar_update: Optional[callable] = None,
609    output_folder: Optional[Union[os.PathLike, str]] = None,
610) -> Tuple[np.ndarray, List[Dict]]:
611    """Track segmented objects over time.
612
613    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
614    for tracking. Please cite it if you use the automated tracking functionality.
615
616    Args:
617        timeseries: The input timeseries of images.
618        segmentation: The segmentation. Expect segmentation results per frame
619            that are relabeled so that segmentation ids don't overlap.
620        gap_closing: If given, gaps in the segmentation are closed with a binary closing
621            operation. The value is used to determine the number of iterations for the closing.
622        min_time_extent: Require a minimal extent in time for the tracked objects.
623        verbose: Verbosity flag. By default, set to 'True'.
624        pbar_init: Function to initialize the progress bar.
625        pbar_update: Function to update the progress bar.
626        output_folder: The folder where the tracking results are stored in CTC format.
627
628    Returns:
629        The tracking result. Each object is colored by its track id.
630        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
631            with each dict encoding a lineage, where keys correspond to parent track ids.
632            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
633    """
634    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)
635
636    if gap_closing is not None and gap_closing > 0:
637        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
638
639    segmentation, lineage = _tracking_impl(
640        timeseries=np.asarray(timeseries),
641        segmentation=segmentation,
642        mode="greedy",
643        min_time_extent=min_time_extent,
644        output_folder=output_folder,
645    )
646    return segmentation, lineage

Track segmented objects over time.

This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.

Arguments:
  • timeseries: The input timeseries of images.
  • segmentation: The segmentation. Expect segmentation results per frame that are relabeled so that segmentation ids don't overlap.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_time_extent: Require a minimal extent in time for the tracked objects.
  • verbose: Verbosity flag. By default, set to 'True'.
  • pbar_init: Function to initialize the progress bar.
  • pbar_update: Function to update the progress bar.
  • output_folder: The folder where the tracking results are stored in CTC format.
Returns:

The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

def automatic_tracking_implementation( timeseries: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, gap_closing: Optional[int] = None, min_time_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, batch_size: int = 1, output_folder: Union[str, os.PathLike, NoneType] = None, **kwargs) -> Tuple[numpy.ndarray, List[Dict]]:
649def automatic_tracking_implementation(
650    timeseries: np.ndarray,
651    predictor: SamPredictor,
652    segmentor: AMGBase,
653    embedding_path: Optional[Union[str, os.PathLike]] = None,
654    gap_closing: Optional[int] = None,
655    min_time_extent: Optional[int] = None,
656    tile_shape: Optional[Tuple[int, int]] = None,
657    halo: Optional[Tuple[int, int]] = None,
658    verbose: bool = True,
659    return_embeddings: bool = False,
660    batch_size: int = 1,
661    output_folder: Optional[Union[os.PathLike, str]] = None,
662    **kwargs,
663) -> Tuple[np.ndarray, List[Dict]]:
664    """Automatically track objects in a timesries based on per-frame automatic segmentation.
665
666    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
667    for tracking. Please cite it if you use the automated tracking functionality.
668
669    Args:
670        timeseries: The input timeseries of images.
671        predictor: The SAM model.
672        segmentor: The instance segmentation class.
673        embedding_path: The path to save pre-computed embeddings.
674        gap_closing: If given, gaps in the segmentation are closed with a binary closing
675            operation. The value is used to determine the number of iterations for the closing.
676        min_time_extent: Require a minimal extent in time for the tracked objects.
677        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
678        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
679        verbose: Verbosity flag. By default, set to 'True'.
680        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
681        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
682        output_folder: The folder where the tracking results are stored in CTC format.
683        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
684
685    Returns:
686        The tracking result. Each object is colored by its track id.
687        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
688            with each dict encoding a lineage, where keys correspond to parent track ids.
689            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
690    """
691    if Trackastra is None:
692        raise RuntimeError(
693            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
694        )
695
696    segmentation, image_embeddings = _segment_slices(
697        timeseries, predictor, segmentor, embedding_path, verbose,
698        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
699        **kwargs,
700    )
701
702    segmentation, lineage = track_across_frames(
703        timeseries=timeseries,
704        segmentation=segmentation,
705        gap_closing=gap_closing,
706        min_time_extent=min_time_extent,
707        verbose=verbose,
708        output_folder=output_folder,
709    )
710
711    if return_embeddings:
712        return segmentation, lineage, image_embeddings
713    else:
714        return segmentation, lineage

Automatically track objects in a timesries based on per-frame automatic segmentation.

This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.

Arguments:
  • timeseries: The input timeseries of images.
  • predictor: The SAM model.
  • segmentor: The instance segmentation class.
  • embedding_path: The path to save pre-computed embeddings.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_time_extent: Require a minimal extent in time for the tracked objects.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
  • batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
  • output_folder: The folder where the tracking results are stored in CTC format.
  • kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:

The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

def get_napari_track_data( segmentation: numpy.ndarray, lineages: List[Dict], n_threads: Optional[int] = None) -> Tuple[numpy.ndarray, Dict[int, List]]:
717def get_napari_track_data(
718    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
719) -> Tuple[np.ndarray, Dict[int, List]]:
720    """Derive the inputs for the napari tracking layer from a tracking result.
721
722    Args:
723        segmentation: The segmentation, after relabeling with track ids.
724        lineages: The lineage information.
725        n_threads: Number of threads for extracting the track data from the segmentation.
726
727    Returns:
728        The array with the track data expected by napari.
729        The parent dictionary for napari.
730    """
731    if n_threads is None:
732        n_threads = mp.cpu_count()
733
734    def compute_props(t):
735        props = regionprops(segmentation[t])
736        # Create the track data representation for napari, which expects:
737        # track_id, timepoint, y, x
738        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
739        return track_data
740
741    with futures.ThreadPoolExecutor(n_threads) as tp:
742        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
743    track_data = [data for data in track_data if data.size > 0]
744    track_data = np.concatenate(track_data)
745
746    # The graph representation of napari uses the children as keys and the parents as values,
747    # whereas our representation uses parents as keys and children as values.
748    # Hence, we need to translate the representation.
749    parent_graph = {
750        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
751    }
752
753    return track_data, parent_graph

Derive the inputs for the napari tracking layer from a tracking result.

Arguments:
  • segmentation: The segmentation, after relabeling with track ids.
  • lineages: The lineage information.
  • n_threads: Number of threads for extracting the track data from the segmentation.
Returns:

The array with the track data expected by napari. The parent dictionary for napari.