micro_sam.multi_dimensional_segmentation

Multi-dimensional segmentation with Segment Anything.

  1"""Multi-dimensional segmentation with segment anything.
  2"""
  3
  4import os
  5import multiprocessing as mp
  6import warnings
  7from concurrent import futures
  8from typing import Dict, List, Optional, Union, Tuple
  9
 10import networkx as nx
 11import numpy as np
 12import torch
 13from scipy.ndimage import binary_closing
 14from skimage.measure import regionprops
 15
 16from bioimage_cpp.segmentation import label, relabel_sequential
 17from bioimage_cpp.graph import UndirectedGraph
 18from bioimage_cpp.utils import segmentation_overlap
 19
 20import elf.segmentation as seg_utils
 21import elf.tracking.tracking_utils as track_utils
 22from elf.tracking.motile_tracking import recolor_segmentation
 23
 24from segment_anything.predictor import SamPredictor
 25
 26try:
 27    from napari.utils import progress as tqdm
 28except ImportError:
 29    from tqdm import tqdm
 30
 31try:
 32    from trackastra.model import Trackastra
 33    from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
 34except ImportError:
 35    Trackastra = None
 36    graph_to_ctc = None
 37    graph_to_napari_tracks = None
 38
 39
 40from . import util
 41from .prompt_based_segmentation import segment_from_mask
 42from .instance_segmentation import AMGBase
 43
 44
 45PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")
 46
 47
 48def _validate_projection(projection):
 49    use_single_point = False
 50    if isinstance(projection, str):
 51        if projection == "mask":
 52            use_box, use_mask, use_points = True, True, False
 53        elif projection == "points":
 54            use_box, use_mask, use_points = False, False, True
 55        elif projection == "box":
 56            use_box, use_mask, use_points = True, False, False
 57        elif projection == "points_and_mask":
 58            use_box, use_mask, use_points = False, True, True
 59        elif projection == "single_point":
 60            use_box, use_mask, use_points = False, False, True
 61            use_single_point = True
 62        else:
 63            raise ValueError(
 64                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
 65                f"You have passed the invalid option {projection}."
 66            )
 67    elif isinstance(projection, dict):
 68        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
 69        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
 70    else:
 71        raise ValueError(f"{projection} is not a supported projection method.")
 72    return use_box, use_mask, use_points, use_single_point
 73
 74
 75# Advanced stopping criterions.
 76# In practice these did not make a big difference, so we do not use this at the moment.
 77# We still leave it here for reference.
 78def _advanced_stopping_criteria(
 79    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
 80):
 81    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
 82        iou_list = [
 83            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
 84        ]
 85        return np.mean(iou_list)
 86
 87    if criterion_choice == 1:
 88        # 1. current metric: iou of current segmentation and the previous slice
 89        iou = util.compute_iou(seg_prev, seg_z)
 90        criterion = iou
 91
 92    elif criterion_choice == 2:
 93        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
 94        iou = util.compute_iou(seg_prev, seg_z)
 95        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
 96        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou
 97
 98    elif criterion_choice == 3:
 99        # 3. iou of current segmented slice w.r.t the previous n slices
100        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))
101
102    return criterion
103
104
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated from the already segmented slices to the slices below, above and
    in between them. Each slice is prompted with the segmentation of the adjacent slice(s).

    Args:
        segmentation: The initial segmentation for the object. It is updated in place and returned.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The indices of the slices for which this object has already been segmented.
            Consecutive entries are treated as neighboring segmented slices, so they should be sorted.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            Segmentation below / above the segmented slices stops once the IOU between
            the previous and the newly segmented slice falls below this value.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def _segment_from_prompt(prompt_mask, z, **extra_kwargs):
        # Prompt SAM for slice z with the given mask, using the selected projection mode.
        # All calls share these settings, including use_single_point.
        return segment_from_mask(
            predictor, prompt_mask, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point, **extra_kwargs,
        )

    def _segment_from_combined_mask(z, z_below, z_above):
        # Segment slice z using the union of the masks in slices z_below and z_above as the prompt.
        seg_prompt = np.logical_or(segmentation[z_below] == 1, segmentation[z_above] == 1)
        segmentation[z] = _segment_from_prompt(seg_prompt, z)
        update_progress(1)

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in steps of `increment` until stopping_criterion(z, z_stop)
        # holds, or until the IOU with the previous slice drops below `threshold` (if given).
        # Returns the last slice that was segmented.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, _, _ = _segment_from_prompt(seg_prev, z, return_all=True)
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # Segment below the min slice.
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # Segment above the max slice.
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # Segment in between the min and max slice.
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # The slices are adjacent -> we don't need to do anything.
                pass

            elif z_start == z0 and stop_lower:  # The lower slice is a stop: we just segment from the upper one.
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # The upper slice is a stop: we just segment from the lower one.
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # There is only one slice in between -> use the combined mask.
                _segment_from_combined_mask(z_start + 1, z_start, z_stop)

            else:  # There is a range of more than 2 slices in between -> segment ranges.
                # Segment from bottom.
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # Segment from top.
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # If the difference between start and stop is even, the middle slice has the same
                # distance from top and bottom. It is not segmented by the ranges above, so we
                # segment it using the combined mask of its two neighbors as the prompt.
                if slice_diff % 2 == 0:
                    _segment_from_combined_mask(z_mid, z_mid - 1, z_mid + 1)

    return segmentation, (z_min, z_max)
234
235
def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along the first axis of a stacked 2d segmentation.

    The foreground is closed with a binary closing that only acts along the first axis
    (`gap_closing` iterations). Objects are then taken from the closed segmentation,
    unless that would merge several objects of the input slice. Each slice is relabeled
    with an increasing offset, so ids do not overlap between slices.

    Args:
        slice_segmentation: The stacked per-slice segmentation.
        gap_closing: The number of iterations for the binary closing.
        pbar_update: Callback that is called once per processed slice.

    Returns:
        The segmentation after gap closing.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def advance_offset(seg, offset):
        # Move the offset past the largest id in seg. An empty slice keeps the current offset;
        # resetting it to 1 would make the ids of later slices collide with earlier ones.
        max_id = int(seg.max())
        return max_id + 1 if max_id > 0 else offset

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices, so these are only relabeled.
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            return seg_z, advance_offset(seg_z, offset)

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        overlaps = segmentation_overlap(closed_z, seg_z)
        matches = {
            seg_id: overlaps.overlaps_for_label_a(seg_id)["label"] for seg_id in range(1, int(closed_z.max() + 1))
        }
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        return seg_new, advance_offset(seg_new, offset)

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation
297
298
def _filter_z_extent(segmentation, min_z_extent):
    """Remove objects whose bounding box spans fewer than `min_z_extent` slices.

    The extent is read from the `regionprops` bounding box as `bbox[3] - bbox[0]`.
    This assumes a 3d label volume with z as the first axis.
    Removed objects are set to 0 in place.

    Args:
        segmentation: The label volume. Modified in place.
        min_z_extent: The minimal number of slices an object must span.

    Returns:
        The filtered segmentation.
    """
    too_short = [
        region.label
        for region in regionprops(segmentation)
        if region.bbox[3] - region.bbox[0] < min_z_extent
    ]
    if not too_short:
        return segmentation
    segmentation[np.isin(segmentation, too_short)] = 0
    return segmentation
310
311
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts. The filter is also applied
            if there is nothing to merge.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
    apply_z_filter = min_z_extent is not None and min_z_extent > 0

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    if len(edges) == 0:
        # Nothing to merge. Copy if we filter below, so that the input array is not modified in place.
        segmentation = slice_segmentation.copy() if apply_z_filter else slice_segmentation
    else:
        uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
        overlaps = np.array([edge["score"] for edge in edges])

        n_nodes = int(slice_segmentation.max() + 1)
        graph = UndirectedGraph(n_nodes)
        graph.insert_edges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # Set background weights to be maximally repulsive.
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = node_labels[slice_segmentation]

    # Apply the z-extent filter and finish the progress bar on every path,
    # including the one where nothing was merged.
    if apply_z_filter:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation
383
384
def _segment_slices(
    data, predictor, segmentor, embedding_path, verbose, tile_shape, halo, batch_size=1, **kwargs
):
    """Run the automatic instance segmentation for each slice of a 3d array.

    Args:
        data: The 3d input data. Slices are taken along the first axis.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class. Its 'initialize' and 'generate' methods are called per slice.
        embedding_path: The path to save pre-computed embeddings.
        verbose: Whether to show progress.
        tile_shape: Shape of the tiles for tiled prediction.
        halo: Overlap of the tiles for tiled prediction.
        batch_size: The batch size to compute image embeddings over planes.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The stacked segmentation (uint32), with ids that do not overlap between slices.
        The image embeddings.

    Raises:
        ValueError: If data is not 3d.
    """
    if data.ndim != 3:
        raise ValueError(f"Expected 3d data, got data with {data.ndim} dimensions.")

    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=data,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
    )

    offset = 0
    segmentation = np.zeros(data.shape, dtype="uint32")

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
        # Convert to the output dtype before adding the offset. Otherwise the offset could overflow
        # a narrower integer dtype returned by the segmentor. The conversion also copies, so the
        # segmentor's own output is not modified.
        seg = np.asarray(segmentor.generate(**kwargs)).astype("uint32")

        # Set offset for instance per slice.
        max_id = int(seg.max())
        if max_id == 0:
            continue
        seg[seg != 0] += offset
        offset += max_id
        segmentation[i] = seg

    return segmentation, image_embeddings
417
418
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The image embeddings, only returned if 'return_embeddings' is True.
    """
    # Segment each slice in 2d, then merge the objects across slices (multicut with beta=0.5).
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs
    )
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    return segmentation
482
483
def _filter_tracks(tracking_result, min_track_length):
    """Remove tracks spanning fewer than `min_track_length` frames and relabel the rest.

    The time extent is read from the `regionprops` bounding box as `bbox[3] - bbox[0]`.
    This assumes a (t, y, x) label array. Short tracks are set to 0 in place before
    the result is relabeled sequentially.

    Args:
        tracking_result: The tracking result. Modified in place.
        min_track_length: The minimal number of frames a track must span.

    Returns:
        The filtered and sequentially relabeled tracking result.
    """
    short_track_ids = []
    for region in regionprops(tracking_result):
        t_first, t_last = region.bbox[0], region.bbox[3]
        if t_last - t_first < min_track_length:
            short_track_ids.append(region.label)

    tracking_result[np.isin(tracking_result, short_track_ids)] = 0
    relabeled, _, _ = relabel_sequential(tracking_result)
    return relabeled
495
496
497def _extract_tracks_and_lineages(segmentations, track_data, parent_graph):
498    # The track data has the following layout: n_tracks x 4
499    # With the following columns:
500    # track_id - id of the track (= result from trackastra)
501    # timepoint
502    # y coordinate
503    # x coordinate
504
505    # Use the last three columns to index the segmentation and get the segmentation id.
506    index = np.round(track_data[:, 1:], 0).astype("int32")
507    index = tuple(index[:, i] for i in range(index.shape[1]))
508    segmentation_ids = segmentations[index]
509
510    # Find the mapping of nodes (= segmented objects) to track-ids.
511    track_ids = track_data[:, 0].astype("int32")
512    assert len(segmentation_ids) == len(track_ids)
513    node_to_track = {k: v for k, v in zip(segmentation_ids, track_ids)}
514
515    # Find the lineages as connected components in the parent graph.
516    # First, we build a proper graph.
517    lineage_graph = nx.Graph()
518    for k, v in parent_graph.items():
519        lineage_graph.add_edge(k, v)
520
521    # Then, find the connected components, and compute the lineage representation expected by micro-sam from it:
522    # E.g. if we have three lineages, the first consisting of three tracks and the second and third of one track each:
523    # [
524    #   {1: [2, 3]},  lineage with a dividing cell
525    #   {4: []}, lineage with just one cell
526    #   {5: []}, lineage with just one cell
527    # ]
528
529    # First, we fill the lineages which have one or more divisions, i.e. trees with more than one node.
530    lineages = []
531    for component in nx.connected_components(lineage_graph):
532        root = next(iter(component))
533        lineage_dict = {}
534
535        def dfs(node, parent):
536            # Avoid revisiting the parent node
537            children = [n for n in lineage_graph[node] if n != parent]
538            lineage_dict[node] = children
539            for child in children:
540                dfs(child, node)
541
542        dfs(root, None)
543        lineages.append(lineage_dict)
544
545    # Then add single node lineages, which are not reflected in the original graph.
546    all_tracks = set(track_ids.tolist())
547    lineage_tracks = []
548    for lineage in lineages:
549        for k, v in lineage.items():
550            lineage_tracks.append(k)
551            lineage_tracks.extend(v)
552    singleton_tracks = list(all_tracks - set(lineage_tracks))
553    lineages.extend([{track: []} for track in singleton_tracks])
554
555    # Make sure node_to_track contains everything.
556    all_seg_ids = np.unique(segmentations)
557    missing_seg_ids = np.setdiff1d(all_seg_ids, list(node_to_track.keys()))
558    node_to_track.update({seg_id: 0 for seg_id in missing_seg_ids})
559    return node_to_track, lineages
560
561
562def _filter_lineages(lineages, tracking_result):
563    track_ids = set(np.unique(tracking_result)) - {0}
564    filtered_lineages = []
565    for lineage in lineages:
566        filtered_lineage = {k: v for k, v in lineage.items() if k in track_ids}
567        if filtered_lineage:
568            filtered_lineages.append(filtered_lineage)
569    return filtered_lineages
570
571
572def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
573    device = "cuda" if torch.cuda.is_available() else "cpu"
574    model = Trackastra.from_pretrained("general_2d", device=device)
575    result = model.track(timeseries, segmentation, mode=mode)
576    try:
577        lineage_graph, _ = result
578    except ValueError:
579        lineage_graph = result
580
581    track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
582    if track_data.size == 0:
583        warnings.warn("Tracking result is empty.")
584        tracking_result = np.zeros_like(segmentation)
585        lineages = []
586        return tracking_result, lineages
587
588    node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
589    tracking_result = recolor_segmentation(segmentation, node_to_track)
590
591    if output_folder is not None:  # Store tracking results in CTC format.
592        graph_to_ctc(lineage_graph, segmentation, outdir=output_folder)
593
594    # TODO
595    # We should check if trackastra supports this already.
596    # Filter out short tracks / lineages.
597    if min_time_extent is not None and min_time_extent > 0:
598        raise NotImplementedError
599
600    # Filter out pruned lineages.
601    # May either be missing due to track filtering or non-consecutive track numbering in trackastra.
602    lineages = _filter_lineages(lineages, tracking_result)
603
604    return tracking_result, lineages
605
606
def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar. Must accept number of steps and description.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    use_gap_closing = gap_closing is not None and gap_closing > 0
    # One step per frame for the gap closing (if enabled) plus one step for the tracking.
    pbar_init(segmentation.shape[0] + 1 if use_gap_closing else 1, "Track objects")
    try:
        if use_gap_closing:
            segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)

        segmentation, lineage = _tracking_impl(
            timeseries=np.asarray(timeseries),
            segmentation=segmentation,
            mode="greedy",
            min_time_extent=min_time_extent,
            output_folder=output_folder,
        )
        pbar_update(1)
    finally:
        # Close the progress bar even if tracking fails.
        pbar_close()

    return segmentation, lineage
658
659
def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Union[Tuple[np.ndarray, List[Dict]], Tuple[np.ndarray, List[Dict], util.ImageEmbeddings]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The image embeddings, only returned if 'return_embeddings' is True.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    # Segment each frame independently, then link the objects over time.
    segmentation, image_embeddings = _segment_slices(
        timeseries, predictor, segmentor, embedding_path, verbose,
        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    return segmentation, lineage
726
727
def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.
            By default, uses the number of CPUs.

    Returns:
        The array with the track data expected by napari: one row per object and timepoint,
            with the columns track_id, timepoint and the centroid coordinates. If the segmentation
            contains no objects, an empty array with the same number of columns is returned.
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        return np.array([[prop.label, t] + list(prop.centroid) for prop in props])

    with futures.ThreadPoolExecutor(n_threads) as tp:
        per_frame_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    per_frame_data = [data for data in per_frame_data if data.size > 0]

    if per_frame_data:
        track_data = np.concatenate(per_frame_data)
    else:
        # No objects in any frame: np.concatenate would fail on an empty list. Return an empty array
        # with one column for the track id, one for the timepoint and one per spatial axis.
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
# The supported projection modes for propagating a mask across slices.
PROJECTION_MODES = (
    "box",
    "mask",
    "points",
    "points_and_mask",
    "single_point",
)
def segment_mask_in_volume( segmentation: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], segmented_slices: numpy.ndarray, stop_lower: bool, stop_upper: bool, iou_threshold: float, projection: Union[str, dict], update_progress: Optional[callable] = None, box_extension: float = 0.0, verbose: bool = False) -> Tuple[numpy.ndarray, Tuple[int, int]]:
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The mask is propagated from the already segmented slices to the slices below the lowest
    and above the highest segmented slice (unless stopped there). The gaps between segmented
    slices are filled by propagating from both ends of each gap.

    Args:
        segmentation: The initial segmentation for the object. It is updated in place.
            The combined prompts select the value 1, so the object is expected to be labeled 1.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The slices for which this object has already been segmented.
            Consecutive entries are treated as lower and upper bound of a gap,
            so they are expected to be sorted in ascending order.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d. Propagation below the
            lowest and above the highest segmented slice stops once the IOU between the previous
            and the newly segmented slice drops below this value.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask'
            or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
            The same projection settings are used for every slice, including slices that are
            segmented from the combined masks of their two neighbors.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def _segment_slice(prompt, z, **kwargs):
        # Segment slice z from a mask prompt, with identical projection settings for every slice.
        return segment_from_mask(
            predictor, prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point, **kwargs,
        )

    def _segment_from_neighbors(z):
        # Segment slice z using the union of the object masks in slices z - 1 and z + 1 as prompt.
        seg_prompt = np.logical_or(segmentation[z - 1] == 1, segmentation[z + 1] == 1)
        segmentation[z] = _segment_slice(seg_prompt, z)
        update_progress(1)

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask slice by slice from z_start in direction `increment` until
        # `stopping_criterion(z, z_stop)` holds or the IOU drops below `threshold`.
        # Returns the last slice that holds the propagated mask.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, _, _ = _segment_slice(seg_prev, z, return_all=True)
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # Segment below the min slice.
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # Segment above the max slice.
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # Segment in between min and max slice.
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # The slices are adjacent -> we don't need to do anything.
                pass

            elif z_start == z0 and stop_lower:  # The lower slice is stop: we just segment from upper.
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # The upper slice is stop: we just segment from lower.
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # There is only one slice in between -> use combined mask.
                _segment_from_neighbors(z_start + 1)

            else:  # There is a range of more than 2 slices in between -> segment ranges.
                # Segment from bottom.
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # Segment from top.
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # If the difference between start and stop is even, the middle slice has the same
                # distance to top and bottom and is not segmented by the ranges above.
                # In this case we segment it using the combined mask of its adjacent slices as prompt.
                if slice_diff % 2 == 0:
                    _segment_from_neighbors(z_mid)

    return segmentation, (z_min, z_max)

Segment an object mask in volumetric data.

Arguments:
  • segmentation: The initial segmentation for the object.
  • predictor: The Segment Anything predictor.
  • image_embeddings: The precomputed image embeddings for the volume.
  • segmented_slices: List of slices for which this object has already been segmented.
  • stop_lower: Whether to stop at the lowest segmented slice.
  • stop_upper: Whether to stop at the topmost segmented slice.
  • iou_threshold: The IOU threshold for continuing segmentation across 3d.
  • projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
  • update_progress: Callback to update an external progress bar.
  • box_extension: Extension factor for increasing the box size after projection. By default, does not increase the projected box size.
  • verbose: Whether to print details about the segmentation steps. By default, set to 'False'.
Returns:

Array with the volumetric segmentation. Tuple with the first and last segmented slice.

def merge_instance_segmentation_3d( slice_segmentation: numpy.ndarray, beta: float = 0.5, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> numpy.ndarray:
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle a napari progress bar from another thread,
            which enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation. If no overlap edges between slices are found, there is nothing to merge
            and the input segmentation (after gap closing, if enabled) is returned without
            applying the min_z_extent filter.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
    if len(edges) == 0:  # Nothing to merge.
        # Finish and close the progress bar on this early exit as well, so that it is not left open.
        pbar_update(1)
        pbar_close()
        return slice_segmentation

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(slice_segmentation.max() + 1)
    graph = UndirectedGraph(n_nodes)
    graph.insert_edges(uv_ids)

    costs = seg_utils.multicut.compute_edge_costs(overlaps)
    # Set background weights to be maximally repulsive.
    if with_background:
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)

    # Map every 2d object id to the id of the 3d object it was merged into.
    segmentation = node_labels[slice_segmentation]
    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation

Merge stacked 2d instance segmentations into a consistent 3d segmentation.

Solves a multicut problem based on the overlap of objects to merge across z.

Arguments:
  • slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
  • beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa. By default, set to '0.5'.
  • with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive. By default, set to 'True'.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
  • verbose: Verbosity flag. By default, set to 'True'.
  • pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar from another thread, which enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The merged segmentation.

def automatic_3d_segmentation( volume: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, batch_size: int = 1, **kwargs) -> numpy.ndarray:
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, Dict]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The image embeddings, only if 'return_embeddings' is True.
    """
    # Segment each slice independently in 2d.
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs
    )
    # Merge the 2d objects across slices; the multicut bias is fixed to 0.5 here.
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    return segmentation

Automatically segment objects in a volume.

First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.

Arguments:
  • volume: The input volume.
  • predictor: The Segment Anything predictor.
  • segmentor: The instance segmentation class.
  • embedding_path: The path to save pre-computed embeddings.
  • with_background: Whether the segmentation has background. By default, set to 'True'.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
  • batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
  • kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:

The segmentation.

def track_across_frames( timeseries: numpy.ndarray, segmentation: numpy.ndarray, gap_closing: Optional[int] = None, min_time_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, output_folder: Union[str, os.PathLike, NoneType] = None) -> Tuple[numpy.ndarray, List[Dict]]:
def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    apply_gap_closing = gap_closing is not None and gap_closing > 0
    # The progress bar must be initialized before it is updated. The step count mirrors
    # merge_instance_segmentation_3d: shape[0] steps for the gap closing plus one final step.
    n_steps = segmentation.shape[0] + 1 if apply_gap_closing else 1
    pbar_init(n_steps, "Track objects")

    if apply_gap_closing:
        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)

    segmentation, lineage = _tracking_impl(
        timeseries=np.asarray(timeseries),
        segmentation=segmentation,
        mode="greedy",
        min_time_extent=min_time_extent,
        output_folder=output_folder,
    )

    pbar_update(1)
    pbar_close()
    return segmentation, lineage

Track segmented objects over time.

This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.

Arguments:
  • timeseries: The input timeseries of images.
  • segmentation: The segmentation. Expect segmentation results per frame that are relabeled so that segmentation ids don't overlap.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_time_extent: Require a minimal extent in time for the tracked objects.
  • verbose: Verbosity flag. By default, set to 'True'.
  • pbar_init: Function to initialize the progress bar.
  • pbar_update: Function to update the progress bar.
  • output_folder: The folder where the tracking results are stored in CTC format.
Returns:

The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

def automatic_tracking_implementation( timeseries: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, gap_closing: Optional[int] = None, min_time_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, batch_size: int = 1, output_folder: Union[str, os.PathLike, NoneType] = None, **kwargs) -> Tuple[numpy.ndarray, List[Dict]]:
def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Union[Tuple[np.ndarray, List[Dict]], Tuple[np.ndarray, List[Dict], Dict]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The image embeddings, only if 'return_embeddings' is True.

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    # Fail early, before running the expensive per-frame segmentation.
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    segmentation, image_embeddings = _segment_slices(
        data=timeseries,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    return segmentation, lineage

Automatically track objects in a timeseries based on per-frame automatic segmentation.

This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.

Arguments:
  • timeseries: The input timeseries of images.
  • predictor: The Segment Anything predictor.
  • segmentor: The instance segmentation class.
  • embedding_path: The path to save pre-computed embeddings.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_time_extent: Require a minimal extent in time for the tracked objects.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
  • batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
  • output_folder: The folder where the tracking results are stored in CTC format.
  • kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:

The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

def get_napari_track_data( segmentation: numpy.ndarray, lineages: List[Dict], n_threads: Optional[int] = None) -> Tuple[numpy.ndarray, Dict[int, List]]:
def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
            The first axis is iterated over as the time axis.
        lineages: The lineage information. Each dict maps a parent track id
            to the list of its child track ids.
        n_threads: Number of threads for extracting the track data from the segmentation.
            By default, the number of CPUs is used.

    Returns:
        The array with the track data expected by napari, with one row
            (track_id, timepoint, *centroid) per object and frame. If the segmentation
            contains no objects, an empty array with the same number of columns is returned.
        The parent dictionary for napari, mapping each child track id to a list with its parent.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    # Columns: track id, timepoint and one centroid coordinate per spatial axis of a frame.
    n_columns = segmentation.ndim + 1

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        return np.array([[prop.label, t] + list(prop.centroid) for prop in props])

    with futures.ThreadPoolExecutor(n_threads) as tp:
        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    track_data = [data for data in track_data if data.size > 0]

    # np.concatenate raises for an empty list, which happens if no frame contains an object.
    if track_data:
        track_data = np.concatenate(track_data)
    else:
        track_data = np.zeros((0, n_columns))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph

Derive the inputs for the napari tracking layer from a tracking result.

Arguments:
  • segmentation: The segmentation, after relabeling with track ids.
  • lineages: The lineage information.
  • n_threads: Number of threads for extracting the track data from the segmentation.
Returns:

The array with the track data expected by napari. The parent dictionary for napari.