micro_sam.multi_dimensional_segmentation
Multi-dimensional segmentation with Segment Anything.
"""Multi-dimensional segmentation with Segment Anything.
"""

import os
import multiprocessing as mp
from concurrent import futures
from typing import Dict, List, Optional, Union, Tuple

import networkx as nx
import nifty
import numpy as np
import torch
from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils
from elf.tracking.motile_tracking import recolor_segmentation

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

try:
    from trackastra.model import Trackastra
    from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
except ImportError:
    Trackastra = None


from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase, mask_data_to_segmentation


PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")


def _validate_projection(projection):
    """Translate a projection mode into the prompt flags used by `segment_from_mask`.

    Args:
        projection: One of PROJECTION_MODES, or a dict with exactly the keys
            'use_box', 'use_mask' and 'use_points'.

    Returns:
        The flags (use_box, use_mask, use_points, use_single_point).
    """
    use_single_point = False
    if isinstance(projection, str):
        if projection == "mask":
            use_box, use_mask, use_points = True, True, False
        elif projection == "points":
            use_box, use_mask, use_points = False, False, True
        elif projection == "box":
            use_box, use_mask, use_points = True, False, False
        elif projection == "points_and_mask":
            use_box, use_mask, use_points = False, True, True
        elif projection == "single_point":
            use_box, use_mask, use_points = False, False, True
            use_single_point = True
        else:
            raise ValueError(
                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
                f"You have passed the invalid option {projection}."
            )
    elif isinstance(projection, dict):
        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
    else:
        raise ValueError(f"{projection} is not a supported projection method.")
    return use_box, use_mask, use_points, use_single_point


# Advanced stopping criterions.
# In practice these did not make a big difference, so we do not use this at the moment.
# We still leave it here for reference.
def _advanced_stopping_criteria(
    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
):
    """Compute an alternative stopping score for slice propagation.

    Args:
        criterion_choice: 1 (IOU with previous slice), 2 (weighted SAM score and IOUs)
            or 3 (mean IOU w.r.t. up to 5 previous slices).

    Raises:
        ValueError: If `criterion_choice` is not 1, 2 or 3.
    """
    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
        iou_list = [
            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
        ]
        return np.mean(iou_list)

    if criterion_choice == 1:
        # 1. current metric: iou of current segmentation and the previous slice
        iou = util.compute_iou(seg_prev, seg_z)
        criterion = iou

    elif criterion_choice == 2:
        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
        iou = util.compute_iou(seg_prev, seg_z)
        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou

    elif criterion_choice == 3:
        # 3. iou of current segmented slice w.r.t the previous n slices
        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))

    else:
        # Previously this fell through to an UnboundLocalError on `criterion`.
        raise ValueError(f"Invalid criterion_choice {criterion_choice}, expected one of 1, 2 or 3.")

    return criterion


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice: below the lowest and above the topmost
    segmented slice (unless stopped there) and into every gap between two segmented slices.

    Args:
        segmentation: The initial segmentation for the object. It is updated in place.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The slices for which this object has already been segmented.
            They do not need to be sorted; duplicates are ignored.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            It is only applied when extending beyond the lowest / topmost segmented slice.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in direction `increment` until the stopping
        # criterion is met or (if given) the IOU with the previous slice drops below threshold.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_combined_prompt(z, z_lower, z_upper):
        # Segment slice z using the union of the masks in two neighboring slices as prompt.
        # Uses the same prompt flags (incl. use_single_point) as the range propagation.
        seg_prompt = np.logical_or(segmentation[z_lower] == 1, segmentation[z_upper] == 1)
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    # Sort and deduplicate: the gap handling below walks consecutive increasing pairs.
    segmented_slices = np.unique(segmented_slices)
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                segment_from_combined_prompt(z_start + 1, z_start, z_stop)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    segment_from_combined_prompt(z_mid, z_mid - 1, z_mid + 1)

    return segmentation, (z_min, z_max)


def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along the first axis of a stacked 2d segmentation.

    Objects from the closed segmentation are used unless they overlap more than one
    object of the initial segmentation in that slice. The result is relabeled so that
    ids are consecutive and do not repeat across slices.

    Args:
        slice_segmentation: The stacked 2d segmentation.
        gap_closing: Number of iterations for the binary closing along the first axis.
        pbar_update: Callback that is called once per processed slice.

    Returns:
        The gap-closed segmentation.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def next_offset(seg, offset):
        # Only advance the offset for non-empty slices; an empty slice must not reset it,
        # otherwise later slices would reuse ids of earlier ones.
        max_id = seg.max()
        return int(max_id) + 1 if max_id > 0 else offset

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            return seg_z, next_offset(seg_z, offset)

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = nifty.ground_truth.overlap(closed_z, seg_z)
        matches = {
            seg_id: matches.overlapArrays(seg_id, sorted=False)[0] for seg_id in range(1, int(closed_z.max() + 1))
        }
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        return seg_new, next_offset(seg_new, offset)

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation


def _filter_z_extent(segmentation, min_z_extent):
    """Set objects whose extent along the first axis is below `min_z_extent` to 0 (in place)."""
    props = regionprops(segmentation)
    filter_ids = []
    for prop in props:
        box = prop.bbox
        z_extent = box[3] - box[0]
        if z_extent < min_z_extent:
            filter_ids.append(prop.label)
    if filter_ids:
        segmentation[np.isin(segmentation, filter_ids)] = 0
    return segmentation


def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enables using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    if len(edges) == 0:
        # No objects overlap across slices, so there is nothing to merge.
        # An empty edge list would not yield the 2d edge arrays needed for the graph below.
        segmentation = slice_segmentation.copy()
    else:
        uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
        overlaps = np.array([edge["score"] for edge in edges])

        # Cast before adding 1 so this cannot overflow in the array dtype.
        n_nodes = int(slice_segmentation.max()) + 1
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # Set background weights to be maximally repulsive.
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation


def _segment_slices(
    data, predictor, segmentor, embedding_path, verbose, tile_shape, halo, with_background=True, batch_size=1, **kwargs
):
    """Segment each slice of a 3d array in 2d and offset the ids so they are unique across slices.

    Returns:
        The stacked segmentation (uint32) and the image embeddings.
    """
    assert data.ndim == 3

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=data,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
    )

    offset = 0
    segmentation = np.zeros(data.shape, dtype="uint32")

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        if isinstance(seg, list):
            if len(seg) == 0:
                continue
            seg = mask_data_to_segmentation(
                seg, with_background=with_background, min_object_size=min_object_size
            )

        # Use the output dtype before adding the offset, so that the shifted ids
        # cannot overflow a smaller integer dtype returned by the segmentor.
        seg = np.asarray(seg).astype("uint32", copy=False)

        # Set offset for instance per slice.
        max_z = int(seg.max())
        if max_z == 0:
            continue
        seg[seg != 0] += offset
        offset = max_z + offset

        segmentation[i] = seg

    return segmentation, image_embeddings


def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation, and additionally the image embeddings if `return_embeddings` is set.
    """
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        with_background=with_background,
        batch_size=batch_size,
        **kwargs
    )
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation


def _filter_tracks(tracking_result, min_track_length):
    """Remove tracks shorter than `min_track_length` frames and relabel sequentially."""
    props = regionprops(tracking_result)
    discard_ids = []
    for prop in props:
        label_id = prop.label
        z_start, z_stop = prop.bbox[0], prop.bbox[3]
        if z_stop - z_start < min_track_length:
            discard_ids.append(label_id)
    tracking_result[np.isin(tracking_result, discard_ids)] = 0
    tracking_result, _, _ = relabel_sequential(tracking_result)
    return tracking_result


def _extract_tracks_and_lineages(segmentations, track_data, parent_graph):
    """Map segmentation ids to track ids and build the lineage representation.

    Args:
        segmentations: The per-frame segmentation.
        track_data: The napari track data (track_id, timepoint, y, x).
        parent_graph: The napari parent graph, mapping child track ids to their parent track id(s).

    Returns:
        The mapping from segmentation id to track id (0 for ids without track) and
        the lineages as a list of {parent_track_id: [child_track_ids]} dicts.
    """
    # The track data has the following layout: n_tracks x 4
    # With the following columns:
    # track_id - id of the track (= result from trackastra)
    # timepoint
    # y coordinate
    # x coordinate

    # Use the last three columns to index the segmentation and get the segmentation id.
    index = np.round(track_data[:, 1:], 0).astype("int32")
    index = tuple(index[:, i] for i in range(index.shape[1]))
    segmentation_ids = segmentations[index]

    # Find the mapping of nodes (= segmented objects) to track-ids.
    track_ids = track_data[:, 0].astype("int32")
    assert len(segmentation_ids) == len(track_ids)
    node_to_track = {k: v for k, v in zip(segmentation_ids, track_ids)}

    # Build a directed parent -> child graph from the napari parent graph (child -> parent).
    # Accept both scalar parents and lists of parents (the napari layer format).
    lineage_graph = nx.DiGraph()
    for child, parents in parent_graph.items():
        if not isinstance(parents, (list, tuple)):
            parents = [parents]
        for parent in parents:
            lineage_graph.add_edge(parent, child)

    # Then, find the connected components, and compute the lineage representation expected by micro-sam from it:
    # E.g. if we have three lineages, the first consisting of three tracks and the second and third of one track each:
    # [
    #   {1: [2, 3]},  lineage with a dividing cell
    #   {4: []}, lineage with just one cell
    #   {5: []}, lineage with just one cell
    # ]

    # First, we fill the lineages which have one or more divisions, i.e. trees with more than one node.
    # Traversal starts at the nodes without a parent, so that every key maps to its actual children
    # (starting from an arbitrary node would turn children into "parents" of their parent).
    lineages = []
    for component in nx.weakly_connected_components(lineage_graph):
        roots = sorted(node for node in component if lineage_graph.in_degree(node) == 0)
        if not roots:  # Only possible for a cyclic graph, which is not a valid lineage; start anywhere.
            roots = [min(component)]
        lineage_dict = {}
        for root in roots:
            for node in nx.dfs_preorder_nodes(lineage_graph, root):
                if node not in lineage_dict:
                    lineage_dict[node] = list(lineage_graph.successors(node))
        for node in component:  # Make sure no node is dropped, even in degenerate graphs.
            lineage_dict.setdefault(node, list(lineage_graph.successors(node)))
        lineages.append(lineage_dict)

    # Then add single node lineages, which are not reflected in the original graph.
    all_tracks = set(track_ids.tolist())
    lineage_tracks = {
        track for lineage in lineages for parent, children in lineage.items() for track in (parent, *children)
    }
    singleton_tracks = sorted(all_tracks - lineage_tracks)
    lineages.extend([{track: []} for track in singleton_tracks])

    # Make sure node_to_track contains everything.
    all_seg_ids = np.unique(segmentations)
    missing_seg_ids = np.setdiff1d(all_seg_ids, list(node_to_track.keys()))
    node_to_track.update({seg_id: 0 for seg_id in missing_seg_ids})
    return node_to_track, lineages


def _filter_lineages(lineages, tracking_result):
    """Drop lineage entries whose track id does not occur in the tracking result."""
    track_ids = set(np.unique(tracking_result)) - {0}
    filtered_lineages = []
    for lineage in lineages:
        filtered_lineage = {k: v for k, v in lineage.items() if k in track_ids}
        if filtered_lineage:
            filtered_lineages.append(filtered_lineage)
    return filtered_lineages


def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
    """Run trackastra on the segmentation and convert the result to micro-sam's representation.

    Raises:
        RuntimeError: If trackastra is not installed.
        NotImplementedError: If `min_time_extent` is set to a positive value.
    """
    # Checked here as well, since track_across_frames can be called directly.
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Trackastra.from_pretrained("general_2d", device=device)
    lineage_graph = model.track(timeseries, segmentation, mode=mode)
    track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
    node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
    tracking_result = recolor_segmentation(segmentation, node_to_track)

    if output_folder is not None:  # Store tracking results in CTC format.
        graph_to_ctc(lineage_graph, segmentation, outdir=output_folder)

    # TODO
    # We should check if trackastra supports this already.
    # Filter out short tracks / lineages.
    if min_time_extent is not None and min_time_extent > 0:
        raise NotImplementedError

    # Filter out pruned lineages.
    # May either be missing due to track filtering or non-consecutive track numbering in trackastra.
    lineages = _filter_lineages(lineages, tracking_result)

    return tracking_result, lineages


def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
        with each dict encoding a lineage, where keys correspond to parent track ids.
        Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    if gap_closing is not None and gap_closing > 0:
        # _preprocess_closing updates the progress bar once per frame.
        pbar_init(segmentation.shape[0], "Close gaps in segmentation")
        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)

    segmentation, lineage = _tracking_impl(
        timeseries=np.asarray(timeseries),
        segmentation=segmentation,
        mode="greedy",
        min_time_extent=min_time_extent,
        output_folder=output_folder,
    )
    pbar_close()
    return segmentation, lineage


def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Tuple[np.ndarray, List[Dict]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
        with each dict encoding a lineage, where keys correspond to parent track ids.
        Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        If `return_embeddings` is set, the image embeddings are returned as third value.
    """
    # Fail early, before the (expensive) per-frame segmentation.
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    segmentation, image_embeddings = _segment_slices(
        timeseries, predictor, segmentor, embedding_path, verbose,
        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
        **kwargs,
    )

    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.

    Returns:
        The array with the track data expected by napari. Empty (zero rows) if there are no objects.
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
        return track_data

    with futures.ThreadPoolExecutor(n_threads) as tp:
        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    track_data = [data for data in track_data if data.size > 0]
    if track_data:
        track_data = np.concatenate(track_data)
    else:
        # np.concatenate fails on an empty list; one column for the id plus one per axis.
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice: below the lowest and above the topmost
    segmented slice (unless stopped there) and into every gap between two segmented slices.

    Args:
        segmentation: The initial segmentation for the object. It is updated in place.
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The slices for which this object has already been segmented.
            They do not need to be sorted; duplicates are ignored.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            It is only applied when extending beyond the lowest / topmost segmented slice.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
            By default, does not increase the projected box size.
        verbose: Whether to print details about the segmentation steps. By default, set to 'False'.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start in direction `increment` until the stopping
        # criterion is met or (if given) the IOU with the previous slice drops below threshold.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_combined_prompt(z, z_lower, z_upper):
        # Segment slice z using the union of the masks in two neighboring slices as prompt.
        # Uses the same prompt flags (incl. use_single_point) as the range propagation.
        seg_prompt = np.logical_or(segmentation[z_lower] == 1, segmentation[z_upper] == 1)
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    # Sort and deduplicate: the gap handling below walks consecutive increasing pairs.
    segmented_slices = np.unique(segmented_slices)
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                segment_from_combined_prompt(z_start + 1, z_start, z_stop)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    segment_from_combined_prompt(z_mid, z_mid - 1, z_mid + 1)

    return segmentation, (z_min, z_max)
Segment an object mask in volumetric data.
Arguments:
- segmentation: The initial segmentation for the object.
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings for the volume.
- segmented_slices: List of slices for which this object has already been segmented.
- stop_lower: Whether to stop at the lowest segmented slice.
- stop_upper: Whether to stop at the topmost segmented slice.
- iou_threshold: The IOU threshold for continuing segmentation across 3d.
- projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
- update_progress: Callback to update an external progress bar.
- box_extension: Extension factor for increasing the box size after projection. By default, does not increase the projected box size.
- verbose: Whether to print details about the segmentation steps. By default, set to 'False'.
Returns:
Array with the volumetric segmentation. Tuple with the first and last segmented slice.
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa. By default, set to '0.5'.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
            By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation. If no objects overlap across slices, this is an unmerged copy
        of the (optionally gap-closed) input.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    if len(edges) == 0:
        # Nothing overlaps across slices, so there is nothing to merge. Do not build the graph:
        # np.array of an empty list is 1d, and the edge handling below needs an (n_edges, 2) array.
        segmentation = slice_segmentation.copy()
    else:
        uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
        overlaps = np.array([edge["score"] for edge in edges])

        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # Set background weights to be maximally repulsive.
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        # Map every per-slice object id to its merged 3d label.
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        segmentation = _filter_z_extent(segmentation, min_z_extent)

    pbar_update(1)
    pbar_close()

    return segmentation
Merge stacked 2d instance segmentations into a consistent 3d segmentation.
Solves a multicut problem based on the overlap of objects to merge across z.
Arguments:
- slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
- beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa. By default, set to '0.5'.
- with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive. By default, set to 'True'.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- verbose: Verbosity flag. By default, set to 'True'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The merged segmentation.
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Automatically segment objects in a volume.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The Segment Anything predictor.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background. By default, set to 'True'.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The segmentation.
        The image embeddings, only if `return_embeddings` is True (then a tuple is returned).
    """
    # Segment each slice independently in 2d.
    segmentation, image_embeddings = _segment_slices(
        data=volume,
        predictor=predictor,
        segmentor=segmentor,
        embedding_path=embedding_path,
        verbose=verbose,
        tile_shape=tile_shape,
        halo=halo,
        with_background=with_background,
        batch_size=batch_size,
        **kwargs
    )
    # Merge the per-slice objects across z.
    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation
Automatically segment objects in a volume.
First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.
Arguments:
- volume: The input volume.
- predictor: The Segment Anything predictor.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- with_background: Whether the segmentation has background. By default, set to 'True'.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
- batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The segmentation.
def track_across_frames(
    timeseries: np.ndarray,
    segmentation: np.ndarray,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    output_folder: Optional[Union[os.PathLike, str]] = None,
) -> Tuple[np.ndarray, List[Dict]]:
    """Track segmented objects over time.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        segmentation: The segmentation. Expect segmentation results per frame
            that are relabeled so that segmentation ids don't overlap.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        verbose: Verbosity flag. By default, set to 'True'.
        pbar_init: Function to initialize the progress bar.
        pbar_update: Function to update the progress bar.
        output_folder: The folder where the tracking results are stored in CTC format.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

    # Initialize the progress bar before anything updates it. This uses the same step budget as
    # merge_instance_segmentation_3d: one step per frame for the gap closing (presumably one update
    # per frame inside _preprocess_closing), plus a single step for the tracking itself.
    if gap_closing is not None and gap_closing > 0:
        pbar_init(segmentation.shape[0] + 1, "Track objects")
        segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Track objects")

    segmentation, lineage = _tracking_impl(
        timeseries=np.asarray(timeseries),
        segmentation=segmentation,
        mode="greedy",
        min_time_extent=min_time_extent,
        output_folder=output_folder,
    )

    pbar_update(1)
    pbar_close()

    return segmentation, lineage
Track segmented objects over time.
This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.
Arguments:
- timeseries: The input timeseries of images.
- segmentation: The segmentation. Expect segmentation results per frame that are relabeled so that segmentation ids don't overlap.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_time_extent: Require a minimal extent in time for the tracked objects.
- verbose: Verbosity flag. By default, set to 'True'.
- pbar_init: Function to initialize the progress bar.
- pbar_update: Function to update the progress bar.
- output_folder: The folder where the tracking results are stored in CTC format.
Returns:
The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
def automatic_tracking_implementation(
    timeseries: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    gap_closing: Optional[int] = None,
    min_time_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    batch_size: int = 1,
    output_folder: Optional[Union[os.PathLike, str]] = None,
    **kwargs,
) -> Union[Tuple[np.ndarray, List[Dict]], Tuple[np.ndarray, List[Dict], util.ImageEmbeddings]]:
    """Automatically track objects in a timeseries based on per-frame automatic segmentation.

    This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf
    for tracking. Please cite it if you use the automated tracking functionality.

    Args:
        timeseries: The input timeseries of images.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_time_extent: Require a minimal extent in time for the tracked objects.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
        batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
        output_folder: The folder where the tracking results are stored in CTC format.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

    Returns:
        The tracking result. Each object is colored by its track id.
        The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts,
            with each dict encoding a lineage, where keys correspond to parent track ids.
            Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
        The image embeddings, only if `return_embeddings` is True.

    Raises:
        RuntimeError: If trackastra is not installed.
    """
    # Fail early, before the expensive per-frame segmentation.
    if Trackastra is None:
        raise RuntimeError(
            "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
        )

    # Segment each frame independently.
    segmentation, image_embeddings = _segment_slices(
        timeseries, predictor, segmentor, embedding_path, verbose,
        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
        **kwargs,
    )

    # Link the per-frame objects over time.
    segmentation, lineage = track_across_frames(
        timeseries=timeseries,
        segmentation=segmentation,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        verbose=verbose,
        output_folder=output_folder,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage
Automatically track objects in a timeseries based on per-frame automatic segmentation.
This function uses Trackastra: https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf for tracking. Please cite it if you use the automated tracking functionality.
Arguments:
- timeseries: The input timeseries of images.
- predictor: The SAM model.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_time_extent: Require a minimal extent in time for the tracked objects.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
- batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
- output_folder: The folder where the tracking results are stored in CTC format.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The tracking result. Each object is colored by its track id. The lineages, which correspond to the cell divisions. Lineages are represented by a list of dicts, with each dict encoding a lineage, where keys correspond to parent track ids. Each key either maps to a list with two child track ids (cell division) or to an empty list (no division).
def get_napari_track_data(
    segmentation: np.ndarray, lineages: List[Dict], n_threads: Optional[int] = None
) -> Tuple[np.ndarray, Dict[int, List]]:
    """Derive the inputs for the napari tracking layer from a tracking result.

    Args:
        segmentation: The segmentation, after relabeling with track ids.
        lineages: The lineage information.
        n_threads: Number of threads for extracting the track data from the segmentation.
            By default, the number of CPUs is used.

    Returns:
        The array with the track data expected by napari. It has one row per object and timepoint,
            with the columns track id, timepoint and the centroid coordinates.
            It is empty (zero rows) if no frame contains an object.
        The parent dictionary for napari.
    """
    if n_threads is None:
        n_threads = mp.cpu_count()

    def compute_props(t):
        props = regionprops(segmentation[t])
        # Create the track data representation for napari, which expects:
        # track_id, timepoint, y, x
        track_data = np.array([[prop.label, t] + list(prop.centroid) for prop in props])
        return track_data

    with futures.ThreadPoolExecutor(n_threads) as tp:
        track_data = list(tp.map(compute_props, range(segmentation.shape[0])))
    track_data = [data for data in track_data if data.size > 0]
    if track_data:
        track_data = np.concatenate(track_data)
    else:
        # np.concatenate raises on an empty list. Return an empty array with the same columns instead:
        # track id, timepoint and one centroid coordinate per spatial axis of a frame.
        track_data = np.zeros((0, segmentation.ndim + 1))

    # The graph representation of napari uses the children as keys and the parents as values,
    # whereas our representation uses parents as keys and children as values.
    # Hence, we need to translate the representation.
    parent_graph = {
        child: [parent] for lineage in lineages for parent, children in lineage.items() for child in children
    }

    return track_data, parent_graph
Derive the inputs for the napari tracking layer from a tracking result.
Arguments:
- segmentation: The segmentation, after relabeling with track ids.
- lineages: The lineage information.
- n_threads: Number of threads for extracting the track data from the segmentation.
Returns:
The array with the track data expected by napari. The parent dictionary for napari.