micro_sam.multi_dimensional_segmentation
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

import os
import multiprocessing as mp
from concurrent import futures
from typing import Dict, List, Optional, Union, Tuple

import networkx as nx
import nifty
import numpy as np
import torch
from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils
from elf.tracking.motile_tracking import recolor_segmentation

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

try:
    from trackastra.model import Trackastra
    from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
except ImportError:
    Trackastra = None


from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase


PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")


def _validate_projection(projection):
    """Translate a projection specification into the prompt flags used by `segment_from_mask`.

    Args:
        projection: Either one of the names in `PROJECTION_MODES`, or a dictionary with
            exactly the keys 'use_box', 'use_mask' and 'use_points'.

    Returns:
        The flags use_box, use_mask, use_points and use_single_point.
        use_single_point is only True for the 'single_point' mode.

    Raises:
        ValueError: If the projection is an unknown name or neither a string nor a dictionary.
    """
    use_single_point = False
    if isinstance(projection, str):
        if projection == "mask":
            use_box, use_mask, use_points = True, True, False
        elif projection == "points":
            use_box, use_mask, use_points = False, False, True
        elif projection == "box":
            use_box, use_mask, use_points = True, False, False
        elif projection == "points_and_mask":
            use_box, use_mask, use_points = False, True, True
        elif projection == "single_point":
            use_box, use_mask, use_points = False, False, True
            use_single_point = True
        else:
            raise ValueError(
                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
                f"You have passed the invalid option {projection}."
            )
    elif isinstance(projection, dict):
        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
    else:
        raise ValueError(f"{projection} is not a supported projection method.")
    return use_box, use_mask, use_points, use_single_point


# Advanced stopping criterions.
# In practice these did not make a big difference, so we do not use this at the moment.
# We still leave it here for reference.
def _advanced_stopping_criteria(
    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
):
    """Compute an alternative stopping criterion for mask propagation (currently unused).

    Args:
        criterion_choice: 1 for the IOU with the previous slice, 2 for a weighted combination
            of that IOU, the SAM score and the IOU with the first segmented slice,
            3 for the mean IOU with up to five previous slices.

    Returns:
        The criterion value.

    Raises:
        ValueError: If criterion_choice is not 1, 2 or 3.
    """
    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
        iou_list = [
            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
        ]
        return np.mean(iou_list)

    if criterion_choice == 1:
        # 1. current metric: iou of current segmentation and the previous slice
        iou = util.compute_iou(seg_prev, seg_z)
        criterion = iou

    elif criterion_choice == 2:
        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
        iou = util.compute_iou(seg_prev, seg_z)
        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou

    elif criterion_choice == 3:
        # 3. iou of current segmented slice w.r.t the previous n slices
        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))

    else:
        # Previously this fell through to an UnboundLocalError on 'criterion'.
        raise ValueError(f"Invalid criterion_choice {criterion_choice}, expected 1, 2 or 3.")

    return criterion


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The mask is propagated slice by slice, using the mask of the neighboring slice as prompt.
    Propagation happens below the lowest and above the topmost segmented slice, and into the
    gaps between segmented slices. The outward propagation stops at the volume border or when
    the IOU with the previous slice drops below `iou_threshold`. The segmentation array is
    updated in place and also returned.

    Args:
        segmentation: The initial segmentation for the object.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: List of slices for which this object has already been segmented,
            in ascending order.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
        projection: The projection method to use. One of 'box', 'mask', 'points',
            'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in direction `increment`, prompting each slice with the previous one.
        # Stops once stopping_criterion(z, z_stop) holds, or when the IOU with the previous slice drops below
        # `threshold`. Returns the last segmented slice (z_start if no slice was segmented).
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_neighbors(z, z_below, z_above):
        # Segment slice z with the union of the masks in z_below and z_above as prompt.
        # Uses the same projection flags as segment_range, including use_single_point.
        seg_prompt = np.logical_or(segmentation[z_below] == 1, segmentation[z_above] == 1)
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                segment_from_neighbors(z_start + 1, z_start, z_stop)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    segment_from_neighbors(z_mid, z_mid - 1, z_mid + 1)

    return segmentation, (z_min, z_max)


def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps of objects along the first axis and relabel every slice with consecutive ids.

    Args:
        slice_segmentation: The stacked per-slice segmentation.
        gap_closing: The number of iterations for the binary closing along the first axis.
        pbar_update: Callback that is called once per processed slice.

    Returns:
        The new segmentation, with ids that are unique across slices.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            # Only advance the offset for non-empty slices. Otherwise an empty slice would reset the
            # offset to 1 and the ids of the following slices would collide with earlier ids.
            max_z = seg_z.max()
            if max_z > 0:
                offset = int(max_z) + 1
            return seg_z, offset

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = nifty.ground_truth.overlap(closed_z, seg_z)
        matches = {
            seg_id: matches.overlapArrays(seg_id, sorted=False)[0] for seg_id in range(1, int(closed_z.max() + 1))
        }
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        max_z = seg_new.max()
        if max_z > 0:
            offset = int(max_z) + 1

        return seg_new, offset

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation


def _filter_z_extent(segmentation, min_z_extent):
    """Set objects whose bounding box spans fewer than `min_z_extent` slices along the first axis to 0.

    The segmentation is modified in place and returned.
    """
    props = regionprops(segmentation)
    filter_ids = []
    for prop in props:
        box = prop.bbox
        z_extent = box[3] - box[0]
        if z_extent < min_z_extent:
            filter_ids.append(prop.label)
    if filter_ids:
        segmentation[np.isin(segmentation, filter_ids)] = 0
    return segmentation


def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
    if len(edges) == 0:
        # Nothing to merge. Copy the input so that the z-extent filtering below does not modify it in place.
        # This path also falls through to the progress bar update / close below.
        segmentation = slice_segmentation.copy()
    else:
        uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
        overlaps = np.array([edge["score"] for edge in edges])

        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # Set background weights to be maximally repulsive.
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation


def _segment_slices(
    data, predictor, segmentor, embedding_path, verbose, tile_shape, halo, batch_size=1, **kwargs
):
    """Compute the embeddings for a 3d volume and run the automatic instance segmentation per slice.

    The ids of each slice are offset so that they are unique across slices.

    Returns:
        The stacked per-slice segmentation (uint32) and the image embeddings.
    """
    assert data.ndim == 3

    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=data,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
    )

    offset = 0
    segmentation = np.zeros(data.shape, dtype="uint32")

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        # Set offset for instance per slice.
        max_z = int(seg.max())
        if max_z == 0:
            continue
        seg[seg != 0] += offset
        offset = max_z + offset
        segmentation[i] = seg

    return segmentation, image_embeddings


def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The image embeddings, only if 'return_embeddings' is True.
    """
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs
    )
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation


def _filter_tracks(tracking_result, min_track_length):
    """Remove tracks that span fewer than `min_track_length` frames and relabel the rest consecutively."""
    props = regionprops(tracking_result)
    discard_ids = []
    for prop in props:
        label_id = prop.label
        z_start, z_stop = prop.bbox[0], prop.bbox[3]
        if z_stop - z_start < min_track_length:
            discard_ids.append(label_id)
    tracking_result[np.isin(tracking_result, discard_ids)] = 0
    tracking_result, _, _ = relabel_sequential(tracking_result)
    return tracking_result


def _extract_tracks_and_lineages(segmentations, track_data, parent_graph):
    """Map segmented objects to track ids and derive the lineage representation used by micro-sam.

    Args:
        segmentations: The per-frame segmentation.
        track_data: The napari track data from trackastra (see layout below).
        parent_graph: Mapping of each child track id to its parent track id.

    Returns:
        The mapping of segmentation ids to track ids (0 for objects without a track) and the lineages.
    """
    # The track data has the following layout: n_tracks x 4
    # With the following columns:
    # track_id - id of the track (= result from trackastra)
    # timepoint
    # y coordinate
    # x coordinate

    # Use the last three columns to index the segmentation and get the segmentation id.
    index = np.round(track_data[:, 1:], 0).astype("int32")
    index = tuple(index[:, i] for i in range(index.shape[1]))
    segmentation_ids = segmentations[index]

    # Find the mapping of nodes (= segmented objects) to track-ids.
    track_ids = track_data[:, 0].astype("int32")
    assert len(segmentation_ids) == len(track_ids)
    node_to_track = dict(zip(segmentation_ids, track_ids))

    # Find the lineages as connected components in the parent graph.
    # First, we build a proper graph.
    lineage_graph = nx.Graph()
    for k, v in parent_graph.items():
        lineage_graph.add_edge(k, v)

    # Then, find the connected components, and compute the lineage representation expected by micro-sam from it:
    # E.g. if we have three lineages, the first consisting of three tracks and the second and third of one track each:
    # [
    #   {1: [2, 3]},  lineage with a dividing cell
    #   {4: []}, lineage with just one cell
    #   {5: []}, lineage with just one cell
    # ]

    # First, we fill the lineages which have one or more divisions, i.e. trees with more than one node.
    lineages = []
    for component in nx.connected_components(lineage_graph):
        # parent_graph maps each child track to its single parent track, so the root of a lineage is
        # the node without a parent. Starting the traversal at any other node would flip parent and child.
        roots = [node for node in component if node not in parent_graph]
        root = roots[0] if roots else next(iter(component))
        lineage_dict = {}

        def dfs(node, parent):
            # Avoid revisiting the parent node
            children = [n for n in lineage_graph[node] if n != parent]
            lineage_dict[node] = children
            for child in children:
                dfs(child, node)

        dfs(root, None)
        lineages.append(lineage_dict)

    # Then add single node lineages, which are not reflected in the original graph.
    all_tracks = set(track_ids.tolist())
    lineage_tracks = []
    for lineage in lineages:
        for k, v in lineage.items():
            lineage_tracks.append(k)
            lineage_tracks.extend(v)
    singleton_tracks = list(all_tracks - set(lineage_tracks))
    lineages.extend([{track: []} for track in singleton_tracks])

    # Make sure node_to_track contains everything.
    all_seg_ids = np.unique(segmentations)
    missing_seg_ids = np.setdiff1d(all_seg_ids, list(node_to_track.keys()))
    node_to_track.update({seg_id: 0 for seg_id in missing_seg_ids})
    return node_to_track, lineages


def _filter_lineages(lineages, tracking_result):
    """Drop lineage entries whose (parent) track id does not occur in the tracking result."""
    track_ids = set(np.unique(tracking_result)) - {0}
    filtered_lineages = []
    for lineage in lineages:
        filtered_lineage = {k: v for k, v in lineage.items() if k in track_ids}
        if filtered_lineage:
            filtered_lineages.append(filtered_lineage)
    return filtered_lineages


def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
    """Run trackastra on the segmentation and recolor the objects by their track ids.

    Returns:
        The tracking result and the lineages.

    Raises:
        NotImplementedError: If a positive `min_time_extent` is requested.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Trackastra.from_pretrained("general_2d", device=device)
    lineage_graph = model.track(timeseries, segmentation, mode=mode)
    track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
    node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
    tracking_result = recolor_segmentation(segmentation, node_to_track)

    if output_folder is not None:  # Store tracking results in CTC format.
        graph_to_ctc(lineage_graph, segmentation, outdir=output_folder)

    # TODO
    # We should check if trackastra supports this already.
    # Filter out short tracks / lineages.
    if min_time_extent is not None and min_time_extent > 0:
        raise NotImplementedError

    # Filter out pruned lineages.
    # These may be missing due to track filtering or non-consecutive track numbering in trackastra.
    lineages = _filter_lineages(lineages, tracking_result)

    return tracking_result, lineages


def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    if gap_closing is not None and gap_closing > 0:
        # _preprocess_closing reports one progress step per frame, so the bar must be initialized first.
        pbar_init(segmentation.shape[0], "Gap closing")
        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
        pbar_close()

    segmentation, lineage = _tracking_impl(
        timeseries=np.asarray(timeseries),
        segmentation=segmentation,
        mode="greedy",
        min_time_extent=min_time_extent,
        output_folder=output_folder,
    )
    return segmentation, lineage


def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Tuple[np.ndarray, List[Dict]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The image embeddings, only if 'return_embeddings' is True.

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    segmentation, image_embeddings = _segment_slices(
        timeseries, predictor, segmentor, embedding_path, verbose,
        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.
            By default, uses the number of CPUs.

    Returns:
        The array with the track data expected by napari. It is empty (with the expected
            number of columns) if no frame contains an object.
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
        return track_data

    with futures.ThreadPoolExecutor(n_threads) as tp:
        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    track_data = [data for data in track_data if data.size > 0]
    if track_data:
        track_data = np.concatenate(track_data)
    else:
        # np.concatenate fails on an empty list: return an empty array with one column for the
        # track id, one for the timepoint and one per spatial axis.
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The mask is propagated slice by slice, using the mask of the neighboring slice as prompt.
    Propagation happens below the lowest and above the topmost segmented slice, and into the
    gaps between segmented slices. The outward propagation stops at the volume border or when
    the IOU with the previous slice drops below `iou_threshold`. The segmentation array is
    updated in place and also returned.

    Args:
        segmentation: The initial segmentation for the object.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: List of slices for which this object has already been segmented,
            in ascending order.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
        projection: The projection method to use. One of 'box', 'mask', 'points',
            'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in direction `increment`, prompting each slice with the previous one.
        # Stops once stopping_criterion(z, z_stop) holds, or when the IOU with the previous slice drops below
        # `threshold`. Returns the last segmented slice (z_start if no slice was segmented).
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_neighbors(z, z_below, z_above):
        # Segment slice z with the union of the masks in z_below and z_above as prompt.
        # Uses the same projection flags as segment_range, including use_single_point.
        seg_prompt = np.logical_or(segmentation[z_below] == 1, segmentation[z_above] == 1)
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                segment_from_neighbors(z_start + 1, z_start, z_stop)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    segment_from_neighbors(z_mid, z_mid - 1, z_mid + 1)

    return segmentation, (z_min, z_max)
Segment an object mask in volumetric data.
Arguments:
- segmentation: The initial segmentation for the object.
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings for the volume.
- segmented_slices: List of slices for which this object has already been segmented.
- stop_lower: Whether to stop at the lowest segmented slice.
- stop_upper: Whether to stop at the topmost segmented slice.
- iou_threshold: The IOU threshold for continuing segmentation across 3d.
- projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
- update_progress: Callback to update an external progress bar.
- box_extension: Extension factor for increasing the box size after projection. By default, does not increase the projected box size.
- verbose: Whether to print details about the segmentation steps. By default, set to 'False'.
Returns:
Array with the volumetric segmentation. Tuple with the first and last segmented slice.
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle a napari progress bar in another thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        # One step per slice for the closing plus one final step for the merge.
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
    if len(edges) == 0:  # Nothing to merge.
        # Finish and close the progress bar before returning, otherwise it is left open.
        pbar_update(1)
        pbar_close()
        return slice_segmentation

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(slice_segmentation.max() + 1)
    graph = nifty.graph.undirectedGraph(n_nodes)
    graph.insertEdges(uv_ids)

    costs = seg_utils.multicut.compute_edge_costs(overlaps)
    # Set background weights to be maximally repulsive.
    if with_background:
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)

    # Map every per-slice id to the id of the 3d object it was assigned to by the multicut.
    segmentation = nifty.tools.take(node_labels, slice_segmentation)
    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation
Merge stacked 2d instance segmentations into a consistent 3d segmentation.
Solves a multicut problem based on the overlap of objects to merge across z.
Arguments:
- slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
- beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa. By default, set to '0.5'.
- with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive. By default, set to 'True'.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- verbose: Verbosity flag. By default, set to 'True'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
Returns:
The merged segmentation.
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The precomputed image embeddings, only if 'return_embeddings' is set to 'True';
        in that case a tuple (segmentation, image_embeddings) is returned.
    """
    # Step 1: per-slice 2d instance segmentation (also computes the embeddings).
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs
    )
    # Step 2: merge the per-slice objects into 3d objects.
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation
Automatically segment objects in a volume.
First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.
Arguments:
- volume: The input volume.
- predictor: The Segment Anything predictor.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- with_background: Whether the segmentation has background. By default, set to 'True'.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
- batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The segmentation.
def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
        with each dict encoding a lineage, where keys correspond to parent track ids.
        Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    # The progress bar callbacks must be initialized before use and closed afterwards,
    # also if tracking fails, so that no progress bar is left open.
    try:
        if gap_closing is not None and gap_closing > 0:
            # Presumably _preprocess_closing advances the bar once per frame (same accounting as
            # in merge_instance_segmentation_3d); the extra step is for the tracking itself.
            pbar_init(segmentation.shape[0] + 1, "Track objects")
            segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
        else:
            pbar_init(1, "Track objects")

        segmentation, lineage = _tracking_impl(
            timeseries=np.asarray(timeseries),
            segmentation=segmentation,
            mode="greedy",
            min_time_extent=min_time_extent,
            output_folder=output_folder,
        )
        pbar_update(1)
    finally:
        pbar_close()

    return segmentation, lineage
Track segmented objects over time.
This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.
Arguments:
- timeseries: The input timeseries of images.
- segmentation: The segmentation. Expect segmentation results per frame that are relabeled so that segmentation ids don't overlap.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_time_extent: Require a minimal extent in time for the tracked objects.
- verbose: Verbosity flag. By default, set to 'True'.
- pbar_init: Function to initialize the progress bar.
- pbar_update: Function to update the progress bar.
- output_folder: The folder where the tracking results are stored in CTC format.
Returns:
The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Union[Tuple[np.ndarray, List[Dict]], Tuple[np.ndarray, List[Dict], util.ImageEmbeddings]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
        with each dict encoding a lineage, where keys correspond to parent track ids.
        Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The precomputed image embeddings, only if 'return_embeddings' is set to 'True'.

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    # Fail early, before the (expensive) per-frame segmentation is computed.
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    segmentation, image_embeddings = _segment_slices(
        timeseries, predictor, segmentor, embedding_path, verbose,
        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage
Automatically track objects in a timeseries based on per-frame automatic segmentation.
This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.
Arguments:
- timeseries: The input timeseries of images.
- predictor: The SAM model.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_time_extent: Require a minimal extent in time for the tracked objects.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
- batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
- output_folder: The folder where the tracking results are stored in CTC format.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.
            By default, the number of CPUs is used.

    Returns:
        The array with the track data expected by napari. It has one row per object and frame
        (track_id, timepoint, followed by the centroid coordinates). It is empty, with the same
        number of columns, if the segmentation does not contain any objects.
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, followed by the centroid coordinates (e.g. y, x).
        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
        return track_data

    with futures.ThreadPoolExecutor(n_threads) as tp:
        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    # Frames without objects yield empty arrays; drop them before concatenating.
    track_data = [data for data in track_data if data.size > 0]
    if track_data:
        track_data = np.concatenate(track_data)
    else:
        # No objects in any frame: np.concatenate would raise on an empty list. Return an empty
        # array with the expected columns: track_id, timepoint and one per spatial axis.
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
Derive the inputs for the napari tracking layer from a tracking result.
Arguments:
- segmentation: The segmentation, after relabeling with track ids.
- lineages: The lineage information.
- n_threads: Number of threads for extracting the track data from the segmentation.
Returns:
The array with the track data expected by napari. The parent dictionary for napari.