micro_sam.multi_dimensional_segmentation
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

import os
import multiprocessing as mp
import warnings
from concurrent import futures
from typing import Dict, List, Optional, Union, Tuple

import networkx as nx
import numpy as np
import torch
from scipy.ndimage import binary_closing
from skimage.measure import regionprops

from bioimage_cpp.segmentation import label, relabel_sequential
from bioimage_cpp.graph import UndirectedGraph
from bioimage_cpp.utils import segmentation_overlap

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils
from elf.tracking.motile_tracking import recolor_segmentation

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

try:
    from trackastra.model import Trackastra
    from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
except ImportError:
    Trackastra = None
    graph_to_ctc = None
    graph_to_napari_tracks = None


from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase


PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")


def _validate_projection(projection):
    """Translate a projection specification into prompt flags.

    Args:
        projection: One of the names in PROJECTION_MODES, or a dict with exactly three entries
            under the keys 'use_box', 'use_mask' and 'use_points'.

    Returns:
        The flags (use_box, use_mask, use_points, use_single_point).
        'use_single_point' can only be enabled via the 'single_point' string mode.

    Raises:
        ValueError: If the projection is an unknown string or neither a string nor a dict.
    """
    use_single_point = False
    if isinstance(projection, str):
        if projection == "mask":
            use_box, use_mask, use_points = True, True, False
        elif projection == "points":
            use_box, use_mask, use_points = False, False, True
        elif projection == "box":
            use_box, use_mask, use_points = True, False, False
        elif projection == "points_and_mask":
            use_box, use_mask, use_points = False, True, True
        elif projection == "single_point":
            use_box, use_mask, use_points = False, False, True
            use_single_point = True
        else:
            raise ValueError(
                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
                f"You have passed the invalid option {projection}."
            )
    elif isinstance(projection, dict):
        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
    else:
        raise ValueError(f"{projection} is not a supported projection method.")
    return use_box, use_mask, use_points, use_single_point


# Advanced stopping criterions.
# In practice these did not make a big difference, so we do not use this at the moment.
# We still leave it here for reference.
def _advanced_stopping_criteria(
    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
):
    """Compute an alternative stopping score for slice-wise projection (currently unused).

    Args:
        z: The current slice index.
        seg_z: The mask predicted for slice z.
        seg_prev: The mask of the previous slice.
        z_start: The slice the projection started from.
        z_increment: Unused.
        segmentation: The volumetric segmentation filled so far.
        criterion_choice: Which criterion to compute: 1, 2 or 3.
        score: The score reported by SAM for seg_z (used by criterion 2).
        increment: The step direction along z (+1 / -1).

    Returns:
        The criterion value.

    Raises:
        ValueError: If criterion_choice is not 1, 2 or 3.
    """
    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
        iou_list = [
            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
        ]
        return np.mean(iou_list)

    if criterion_choice == 1:
        # 1. current metric: iou of current segmentation and the previous slice
        iou = util.compute_iou(seg_prev, seg_z)
        criterion = iou

    elif criterion_choice == 2:
        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
        iou = util.compute_iou(seg_prev, seg_z)
        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou

    elif criterion_choice == 3:
        # 3. iou of current segmented slice w.r.t the previous n slices
        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))

    else:
        # Previously this fell through to an UnboundLocalError on 'criterion'.
        raise ValueError(f"Invalid criterion_choice {criterion_choice}, expected 1, 2 or 3.")

    return criterion


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice from the already segmented slices: below the lowest one,
    above the highest one, and into the gaps between consecutive segmented slices. The segmentation
    array is modified in place and also returned.

    Args:
        segmentation: The initial segmentation for the object. The combined prompts for gap slices
            compare against the value 1, so the object is expected to be labeled 1.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: Array of the slice indices for which this object has already been segmented.
            Consecutive entries are treated as gap boundaries, so the indices are expected in ascending order.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            Only applied when projecting beyond the lowest / topmost segmented slice.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in steps of 'increment' until stopping_criterion(z, z_stop)
        # holds, or until the IOU with the previous slice drops below 'threshold'.
        # Returns the last slice that was segmented.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_combined_mask(z, seg_prompt):
        # Segment slice z from a prompt combining the masks of both neighbouring slices.
        # Uses the same prompt flags as segment_range, including use_single_point.
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
                segment_from_combined_mask(z_start + 1, seg_prompt)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
                    segment_from_combined_mask(z_mid, seg_prompt)

    return segmentation, (z_min, z_max)


def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along the first axis in a stack of per-slice instance segmentations.

    A binary closing is applied along axis 0 only. Per slice, connected components of the closed
    foreground are used, unless a component overlaps more than one object of the input slice;
    this indicates wrongly merged nearby objects, which then keep their input shape instead.
    Ids are relabeled per slice with a running offset so that they do not collide across slices.

    Args:
        slice_segmentation: The stacked per-slice segmentation.
        gap_closing: Number of iterations for the binary closing.
        pbar_update: Progress callback, called once per slice.

    Returns:
        The gap-closed segmentation (same shape and dtype as the input).
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            # Only advance the offset if the slice contains objects: for an empty slice (max == 0)
            # 'max + 1' would reset the offset to 1 and later slices would reuse earlier ids.
            max_z = int(seg_z.max())
            if max_z > 0:
                offset = max_z + 1
            return seg_z, offset

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = segmentation_overlap(closed_z, seg_z)
        matches = {
            seg_id: matches.overlaps_for_label_a(seg_id)["label"] for seg_id in range(1, int(closed_z.max() + 1))
        }
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        max_z = seg_new.max()
        if max_z > 0:
            offset = int(max_z) + 1

        return seg_new, offset

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation


def _filter_z_extent(segmentation, min_z_extent):
    """Set objects spanning fewer than min_z_extent slices along axis 0 to background (in place)."""
    props = regionprops(segmentation)
    filter_ids = []
    for prop in props:
        box = prop.bbox
        z_extent = box[3] - box[0]
        if z_extent < min_z_extent:
            filter_ids.append(prop.label)
    if filter_ids:
        segmentation[np.isin(segmentation, filter_ids)] = 0
    return segmentation


def _merge_with_multicut(slice_segmentation, edges, beta, with_background):
    """Merge objects across slices by solving a multicut on the slice-overlap graph.

    Args:
        slice_segmentation: The stacked per-slice segmentation.
        edges: Non-empty list of overlap edges with the keys 'source', 'target' and 'score'.
        beta: The bias term for the multicut.
        with_background: Whether edges touching id 0 are set to be maximally repulsive.

    Returns:
        The merged segmentation.
    """
    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(slice_segmentation.max() + 1)
    graph = UndirectedGraph(n_nodes)
    graph.insert_edges(uv_ids)

    costs = seg_utils.multicut.compute_edge_costs(overlaps)
    # Set background weights to be maximally repulsive.
    if with_background:
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
    return node_labels[slice_segmentation]


def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
    if len(edges) == 0:
        # Nothing to merge. Copy so that the z-extent filter below cannot modify the caller's array.
        # (Previously this path returned early, skipping the filter and never closing the progress bar.)
        segmentation = slice_segmentation.copy()
    else:
        segmentation = _merge_with_multicut(slice_segmentation, edges, beta, with_background)

    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation


def _segment_slices(
    data, predictor, segmentor, embedding_path, verbose, tile_shape, halo, batch_size=1, **kwargs
):
    """Run 2d automatic instance segmentation on every slice of a 3d array.

    Object ids are offset per slice so that they do not collide across slices.

    Returns:
        The stacked uint32 segmentation and the precomputed image embeddings.
    """
    assert data.ndim == 3

    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=data,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
    )

    offset = 0
    segmentation = np.zeros(data.shape, dtype="uint32")

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        # Set offset for instance per slice.
        max_z = int(seg.max())
        if max_z == 0:
            continue
        seg[seg != 0] += offset
        offset = max_z + offset
        segmentation[i] = seg

    return segmentation, image_embeddings


def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> np.ndarray:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The image embeddings, only returned if 'return_embeddings' is True.
    """
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs
    )
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation


def _filter_tracks(tracking_result, min_track_length):
    """Remove tracks spanning fewer than min_track_length frames, then relabel sequentially."""
    props = regionprops(tracking_result)
    discard_ids = []
    for prop in props:
        label_id = prop.label
        z_start, z_stop = prop.bbox[0], prop.bbox[3]
        if z_stop - z_start < min_track_length:
            discard_ids.append(label_id)
    tracking_result[np.isin(tracking_result, discard_ids)] = 0
    tracking_result, _, _ = relabel_sequential(tracking_result)
    return tracking_result


def _extract_tracks_and_lineages(segmentations, track_data, parent_graph):
    """Map segmentation ids to track ids and build the lineage representation used by micro-sam.

    Returns:
        A dict mapping every segmentation id to its track id (0 for ids without a track),
        and the list of lineage dicts.
    """
    # The track data has the following layout: n_tracks x 4
    # With the following columns:
    # track_id - id of the track (= result from trackastra)
    # timepoint
    # y coordinate
    # x coordinate

    # Use the last three columns to index the segmentation and get the segmentation id.
    index = np.round(track_data[:, 1:], 0).astype("int32")
    index = tuple(index[:, i] for i in range(index.shape[1]))
    segmentation_ids = segmentations[index]

    # Find the mapping of nodes (= segmented objects) to track-ids.
    track_ids = track_data[:, 0].astype("int32")
    assert len(segmentation_ids) == len(track_ids)
    node_to_track = dict(zip(segmentation_ids, track_ids))

    # Find the lineages as connected components in the parent graph.
    # First, we build a proper graph.
    lineage_graph = nx.Graph()
    for k, v in parent_graph.items():
        lineage_graph.add_edge(k, v)

    # Then, find the connected components, and compute the lineage representation expected by micro-sam from it:
    # E.g. if we have three lineages, the first consisting of three tracks and the second and third of one track each:
    # [
    #   {1: [2, 3]},  lineage with a dividing cell
    #   {4: []}, lineage with just one cell
    #   {5: []}, lineage with just one cell
    # ]

    # First, we fill the lineages which have one or more divisions, i.e. trees with more than one node.
    lineages = []
    for component in nx.connected_components(lineage_graph):
        root = next(iter(component))
        lineage_dict = {}

        def dfs(node, parent):
            # Avoid revisiting the parent node
            children = [n for n in lineage_graph[node] if n != parent]
            lineage_dict[node] = children
            for child in children:
                dfs(child, node)

        dfs(root, None)
        lineages.append(lineage_dict)

    # Then add single node lineages, which are not reflected in the original graph.
    all_tracks = set(track_ids.tolist())
    lineage_tracks = []
    for lineage in lineages:
        for k, v in lineage.items():
            lineage_tracks.append(k)
            lineage_tracks.extend(v)
    singleton_tracks = list(all_tracks - set(lineage_tracks))
    lineages.extend([{track: []} for track in singleton_tracks])

    # Make sure node_to_track contains everything.
    all_seg_ids = np.unique(segmentations)
    missing_seg_ids = np.setdiff1d(all_seg_ids, list(node_to_track.keys()))
    node_to_track.update({seg_id: 0 for seg_id in missing_seg_ids})
    return node_to_track, lineages


def _filter_lineages(lineages, tracking_result):
    """Drop lineage entries whose track id does not occur in the tracking result."""
    track_ids = set(np.unique(tracking_result)) - {0}
    filtered_lineages = []
    for lineage in lineages:
        filtered_lineage = {k: v for k, v in lineage.items() if k in track_ids}
        if filtered_lineage:
            filtered_lineages.append(filtered_lineage)
    return filtered_lineages


def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
    """Track the segmented objects with trackastra and recolor the segmentation by track id.

    Returns:
        The tracking result and the list of lineages.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Trackastra.from_pretrained("general_2d", device=device)
    result = model.track(timeseries, segmentation, mode=mode)
    # Depending on the trackastra version, 'track' returns either the graph or a tuple/list whose
    # first element is the graph. Check the type explicitly: tuple-unpacking a graph iterates over its
    # nodes, so it would silently succeed for a graph with exactly two nodes and yield node ids.
    lineage_graph = result[0] if isinstance(result, (tuple, list)) else result

    track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
    if track_data.size == 0:
        warnings.warn("Tracking result is empty.")
        tracking_result = np.zeros_like(segmentation)
        lineages = []
        return tracking_result, lineages

    node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
    tracking_result = recolor_segmentation(segmentation, node_to_track)

    if output_folder is not None:  # Store tracking results in CTC format.
        graph_to_ctc(lineage_graph, segmentation, outdir=output_folder)

    # TODO
    # We should check if trackastra supports this already.
    # Filter out short tracks / lineages.
    if min_time_extent is not None and min_time_extent > 0:
        raise NotImplementedError("Filtering tracks by 'min_time_extent' is not supported yet.")

    # Filter out pruned lineages.
    # May either be missing due to track filtering or non-consecutive track numbering in trackastra.
    lineages = _filter_lineages(lineages, tracking_result)

    return tracking_result, lineages


def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    if gap_closing is not None and gap_closing > 0:
        # Initialize the progress bar before it is updated once per frame, as in merge_instance_segmentation_3d.
        pbar_init(segmentation.shape[0], "Gap closing")
        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)

    segmentation, lineage = _tracking_impl(
        timeseries=np.asarray(timeseries),
        segmentation=segmentation,
        mode="greedy",
        min_time_extent=min_time_extent,
        output_folder=output_folder,
    )
    pbar_close()
    return segmentation, lineage


def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Tuple[np.ndarray, List[Dict]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The image embeddings, only returned if 'return_embeddings' is True.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    segmentation, image_embeddings = _segment_slices(
        timeseries, predictor, segmentor, embedding_path, verbose,
        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.

    Returns:
        The array with the track data expected by napari (empty if no frame contains objects).
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
        return track_data

    with futures.ThreadPoolExecutor(n_threads) as tp:
        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    track_data = [data for data in track_data if data.size > 0]
    if track_data:
        track_data = np.concatenate(track_data)
    else:
        # No objects in any frame: np.concatenate would fail on an empty list.
        # Columns: track id, timepoint, plus one centroid coordinate per spatial axis.
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice from the already segmented slices: below the lowest one,
    above the highest one, and into the gaps between consecutive segmented slices. The segmentation
    array is modified in place and also returned.

    Args:
        segmentation: The initial segmentation for the object. The combined prompts for gap slices
            compare against the value 1, so the object is expected to be labeled 1.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: Array of the slice indices for which this object has already been segmented.
            Consecutive entries are treated as gap boundaries, so the indices are expected in ascending order.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            Only applied when projecting beyond the lowest / topmost segmented slice.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in steps of 'increment' until stopping_criterion(z, z_stop)
        # holds, or until the IOU with the previous slice drops below 'threshold'.
        # Returns the last slice that was segmented.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_combined_mask(z, seg_prompt):
        # Segment slice z from a prompt combining the masks of both neighbouring slices.
        # Uses the same prompt flags as segment_range, including use_single_point, so that the
        # 'single_point' projection is honoured here too.
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
                segment_from_combined_mask(z_start + 1, seg_prompt)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
                    segment_from_combined_mask(z_mid, seg_prompt)

    return segmentation, (z_min, z_max)
Segment an object mask in volumetric data.
Arguments:
- segmentation: The initial segmentation for the object.
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings for the volume.
- segmented_slices: List of slices for which this object has already been segmented.
- stop_lower: Whether to stop at the lowest segmented slice.
- stop_upper: Whether to stop at the topmost segmented slice.
- iou_threshold: The IOU threshold for continuing segmentation across 3d.
- projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
- update_progress: Callback to update an external progress bar.
- box_extension: Extension factor for increasing the box size after projection. By default, does not increase the projected box size.
- verbose: Whether to print details about the segmentation steps. By default, set to 'False'.
Returns:
Array with the volumetric segmentation. Tuple with the first and last segmented slice.
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
    if len(edges) == 0:  # Nothing to merge.
        # Finish and close the progress bar that was opened above before returning early.
        pbar_update(1)
        pbar_close()
        return slice_segmentation

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(slice_segmentation.max() + 1)
    graph = UndirectedGraph(n_nodes)
    graph.insert_edges(uv_ids)

    costs = seg_utils.multicut.compute_edge_costs(overlaps)
    # Set background weights to be maximally repulsive.
    if with_background:
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

    # NOTE(review): the decomposition receives `1.0 - costs`; confirm this matches the
    # sign convention expected by multicut_decomposition.
    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)

    # Map every per-slice id to its merged 3d id.
    segmentation = node_labels[slice_segmentation]
    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation
Merge stacked 2d instance segmentations into a consistent 3d segmentation.
Solves a multicut problem based on the overlap of objects to merge across z.
Arguments:
- slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
- beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa. By default, set to '0.5'.
- with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive. By default, set to 'True'.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- verbose: Verbosity flag. By default, set to 'True'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. To enable using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The merged segmentation.
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The precomputed image embeddings, only if 'return_embeddings' is set.
    """
    # Step 1: independent 2d instance segmentation of every slice.
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs
    )
    # Step 2: merge the per-slice objects across z.
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    return segmentation
Automatically segment objects in a volume.
First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.
Arguments:
- volume: The input volume.
- predictor: The Segment Anything predictor.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- with_background: Whether the segmentation has background. By default, set to 'True'.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
- batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The segmentation. If 'return_embeddings' is set, the precomputed image embeddings are returned as well.
def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
        with each dict encoding a lineage, where keys correspond to parent track ids.
        Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    # Initialize the progress bar before anything reports to it. The step count follows
    # merge_instance_segmentation_3d: (assumed) one step per frame for the gap closing, plus one
    # final step for the tracking itself.
    if gap_closing is not None and gap_closing > 0:
        pbar_init(segmentation.shape[0] + 1, "Track objects")
        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Track objects")

    segmentation, lineage = _tracking_impl(
        timeseries=np.asarray(timeseries),
        segmentation=segmentation,
        mode="greedy",
        min_time_extent=min_time_extent,
        output_folder=output_folder,
    )

    pbar_update(1)
    pbar_close()
    return segmentation, lineage
Track segmented objects over time.
This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.
Arguments:
- timeseries: The input timeseries of images.
- segmentation: The segmentation. Expect segmentation results per frame that are relabeled so that segmentation ids don't overlap.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_time_extent: Require a minimal extent in time for the tracked objects.
- verbose: Verbosity flag. By default, set to 'True'.
- pbar_init: Function to initialize the progress bar.
- pbar_update: Function to update the progress bar.
- output_folder: The folder where the tracking results are stored in CTC format.
Returns:
The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Union[Tuple[np.ndarray, List[Dict]], Tuple[np.ndarray, List[Dict], util.ImageEmbeddings]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
        with each dict encoding a lineage, where keys correspond to parent track ids.
        Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The precomputed image embeddings, only if 'return_embeddings' is set.

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    # Per-frame automatic segmentation (keyword call, consistent with automatic_3d_segmentation).
    segmentation, image_embeddings = _segment_slices(
        data=timeseries,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    return segmentation, lineage
Automatically track objects in a timeseries based on per-frame automatic segmentation.
This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.
Arguments:
- timeseries: The input timeseries of images.
- predictor: The SAM model.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_time_extent: Require a minimal extent in time for the tracked objects.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
- batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
- output_folder: The folder where the tracking results are stored in CTC format.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division). If 'return_embeddings' is set, the precomputed image embeddings are returned as a third value.
def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.
            By default, uses the number of available CPUs.

    Returns:
        The array with the track data expected by napari. It has one row per object and frame:
            track_id, timepoint, followed by the object's centroid coordinates.
            If no frame contains an object, an empty array with the same number of columns is returned.
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        return np.array([[prop.label, t] + list(prop.centroid) for prop in props])

    with futures.ThreadPoolExecutor(n_threads) as tp:
        per_frame = list(tp.map(compute_props, range(segmentation.shape[0])))
    per_frame = [data for data in per_frame if data.size > 0]

    if per_frame:
        track_data = np.concatenate(per_frame)
    else:
        # np.concatenate fails on an empty list: return an empty array with the expected columns
        # (track_id, timepoint and one centroid coordinate per spatial axis of a frame).
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
Derive the inputs for the napari tracking layer from a tracking result.
Arguments:
- segmentation: The segmentation, after relabeling with track ids.
- lineages: The lineage information.
- n_threads: Number of threads for extracting the track data from the segmentation.
Returns:
The array with the track data expected by napari. The parent dictionary for napari.