micro_sam.multi_dimensional_segmentation
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

import os
from typing import Optional, Union, Tuple

import numpy as np
from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

import nifty

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase, mask_data_to_segmentation


PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")


def _validate_projection(projection):
    """Translate a projection specification into the prompt flags used for `segment_from_mask`.

    Args:
        projection: Either one of the names in `PROJECTION_MODES`, or a dictionary with exactly
            the keys 'use_box', 'use_mask' and 'use_points'.

    Returns:
        Tuple (use_box, use_mask, use_points, use_single_point).

    Raises:
        ValueError: If the projection name is unknown, if a projection dictionary does not have exactly
            the expected keys, or if the projection is neither a string nor a dictionary.
    """
    use_single_point = False
    if isinstance(projection, str):
        if projection == "mask":
            use_box, use_mask, use_points = True, True, False
        elif projection == "points":
            use_box, use_mask, use_points = False, False, True
        elif projection == "box":
            use_box, use_mask, use_points = True, False, False
        elif projection == "points_and_mask":
            use_box, use_mask, use_points = False, True, True
        elif projection == "single_point":
            use_box, use_mask, use_points = False, False, True
            use_single_point = True
        else:
            raise ValueError(
                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
                f"You have passed the invalid option {projection}."
            )
    elif isinstance(projection, dict):
        required_keys = ("use_box", "use_mask", "use_points")
        # Validate explicitly instead of using `assert`, which is stripped when running with `python -O`.
        if set(projection.keys()) != set(required_keys):
            raise ValueError(
                f"The projection dictionary must have exactly the keys {required_keys}, "
                f"got {list(projection.keys())}."
            )
        use_box, use_mask, use_points = (projection[key] for key in required_keys)
    else:
        raise ValueError(f"{projection} is not a supported projection method.")
    return use_box, use_mask, use_points, use_single_point


# Advanced stopping criteria.
# In practice these did not make a big difference, so we do not use this at the moment.
# We still leave it here for reference.
def _advanced_stopping_criteria(
    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
):
    """Compute an alternative score for deciding whether to continue segmenting across slices.

    Args:
        z: The current slice index.
        seg_z: The segmentation predicted for slice z.
        seg_prev: The segmentation of the previously segmented slice.
        z_start: The slice the segmentation was started from.
        z_increment: Not used in the body.
        segmentation: The volumetric segmentation computed so far.
        criterion_choice: Which criterion to compute: 1, 2 or 3.
        score: The score of the current prediction; only used by criterion 2.
        increment: The step between consecutive slices.

    Returns:
        The value of the chosen criterion.

    Raises:
        ValueError: If criterion_choice is not 1, 2 or 3.
    """
    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
        iou_list = [
            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
        ]
        return np.mean(iou_list)

    if criterion_choice == 1:
        # 1. current metric: iou of current segmentation and the previous slice
        iou = util.compute_iou(seg_prev, seg_z)
        criterion = iou

    elif criterion_choice == 2:
        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
        iou = util.compute_iou(seg_prev, seg_z)
        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou

    elif criterion_choice == 3:
        # 3. iou of current segmented slice w.r.t the previous n slices
        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))

    else:
        # Previously this fell through to an UnboundLocalError on `criterion`.
        raise ValueError(f"Invalid criterion_choice {criterion_choice}, expected 1, 2 or 3.")

    return criterion


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice from the already segmented slices. It is extended below
    the lowest and above the topmost segmented slice (unless stopped there), and it fills the gaps
    between segmented slices. The `segmentation` array is written in place and also returned.

    Args:
        segmentation: The initial segmentation for the object.
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: List of slices for which this object has already been segmented.
            The slices are sorted and de-duplicated before use.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            It is only applied when extending below the lowest / above the topmost segmented slice.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
        verbose: Whether to print details about the segmentation steps.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_neighbors(z):
        # Prompt slice z with the union of the object masks in slices z - 1 and z + 1.
        # Uses the same projection flags as `segment_range`, including `use_single_point`.
        seg_prompt = np.logical_or(segmentation[z - 1] == 1, segmentation[z + 1] == 1)
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    # Sort and de-duplicate, so that consecutive pairs below describe ordered gaps between segmented slices.
    segmented_slices = np.unique(segmented_slices)
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                segment_from_neighbors(int(z_start) + 1)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    segment_from_neighbors(z_mid)

    return segmentation, (z_min, z_max)


def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along z in a stacked 2d segmentation without merging objects in-plane.

    Args:
        slice_segmentation: The stacked 2d segmentation, with z as the first axis.
        gap_closing: The number of iterations for the binary closing along z.
        pbar_update: Callback to update a progress bar; called once per slice.

    Returns:
        The new segmentation. Labels are unique across the whole stack and start at 1.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def advance_offset(seg, offset):
        # Only advance the offset if the slice has objects. For an empty slice max() is 0,
        # and resetting the offset to 1 would reuse label ids of previous slices.
        max_z = seg.max()
        return int(max_z) + 1 if max_z > 0 else offset

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            return seg_z, advance_offset(seg_z, offset)

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = nifty.ground_truth.overlap(closed_z, seg_z)
        matches = {
            seg_id: matches.overlapArrays(seg_id, sorted=False)[0] for seg_id in range(1, int(closed_z.max() + 1))
        }
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=int(seg_new.max()) + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        return seg_new, advance_offset(seg_new, offset)

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation


def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.
    If no overlap edges are found between slices, there is nothing to merge. In that case a copy
    of the (optionally gap-closed) input segmentation is used, followed by the optional z-extent filter.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutively across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    # The reshape keeps the (n_edges, 2) layout also for an empty edge list,
    # where np.array([]) would otherwise be one-dimensional.
    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges]).reshape(-1, 2)
    overlaps = np.array([edge["score"] for edge in edges])

    if len(uv_ids) == 0:
        # Without edges there is nothing to merge. Building the graph or indexing
        # `uv_ids` along axis 1 would fail for the empty edge array.
        segmentation = slice_segmentation.copy()
    else:
        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # set background weights to be maximally repulsive
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        # Remove objects whose bounding box spans fewer than `min_z_extent` slices along the first axis.
        props = regionprops(segmentation)
        filter_ids = [prop.label for prop in props if prop.bbox[3] - prop.bbox[0] < min_z_extent]
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation


def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            'min_object_size' is removed from them and used when converting a list of masks
            to a label image.

    Returns:
        The segmentation. If `return_embeddings` is True, a tuple of the segmentation
        and the precomputed image embeddings.
    """
    offset = 0
    segmentation = np.zeros(volume.shape[:3], dtype="uint32")

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        # The segmentor returns either a (possibly empty) list of masks or a label image.
        if isinstance(seg, list):
            if len(seg) == 0:
                continue
            seg = mask_data_to_segmentation(
                seg, with_background=with_background, min_object_size=min_object_size
            )

        # Work on a uint32 copy: the offset must not overflow a smaller input dtype,
        # and the segmentor's output is not modified in place.
        seg = seg.astype("uint32")

        # Set offset for instance per slice.
        max_z = int(seg.max())
        if max_z == 0:
            continue
        seg[seg != 0] += offset
        offset += max_z

        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation
PROJECTION_MODES =
('box', 'mask', 'points', 'points_and_mask', 'single_point')
def
segment_mask_in_volume( segmentation: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], segmented_slices: numpy.ndarray, stop_lower: bool, stop_upper: bool, iou_threshold: float, projection: Union[str, dict], update_progress: Optional[<built-in function callable>] = None, box_extension: float = 0.0, verbose: bool = False) -> Tuple[numpy.ndarray, Tuple[int, int]]:
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice from the already segmented slices. It is extended below
    the lowest and above the topmost segmented slice (unless stopped there), and it fills the gaps
    between segmented slices. The `segmentation` array is written in place and also returned.

    Args:
        segmentation: The initial segmentation for the object.
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: List of slices for which this object has already been segmented.
            The slices are sorted and de-duplicated before use.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            It is only applied when extending below the lowest / above the topmost segmented slice.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
        verbose: Whether to print details about the segmentation steps.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        return z - increment

    def segment_from_neighbors(z):
        # Prompt slice z with the union of the object masks in slices z - 1 and z + 1.
        # Uses the same projection flags as `segment_range`, including `use_single_point`.
        seg_prompt = np.logical_or(segmentation[z - 1] == 1, segmentation[z + 1] == 1)
        segmentation[z] = segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            box_extension=box_extension, use_single_point=use_single_point,
        )
        update_progress(1)

    # Sort and de-duplicate, so that consecutive pairs below describe ordered gaps between segmented slices.
    segmented_slices = np.unique(segmented_slices)
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                segment_from_neighbors(int(z_start) + 1)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    segment_from_neighbors(z_mid)

    return segmentation, (z_min, z_max)
Segment an object mask in volumetric data.
Arguments:
- segmentation: The initial segmentation for the object.
- predictor: The segment anything predictor.
- image_embeddings: The precomputed image embeddings for the volume.
- segmented_slices: List of slices for which this object has already been segmented.
- stop_lower: Whether to stop at the lowest segmented slice.
- stop_upper: Whether to stop at the topmost segmented slice.
- iou_threshold: The IOU threshold for continuing segmentation across 3d.
- projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
- update_progress: Callback to update an external progress bar.
- box_extension: Extension factor for increasing the box size after projection.
- verbose: Whether to print details about the segmentation steps.
Returns:
The array with the volumetric segmentation, and a tuple with the first and last segmented slice.
def
merge_instance_segmentation_3d( slice_segmentation: numpy.ndarray, beta: float = 0.5, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> numpy.ndarray:
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.
    If no overlap edges are found between slices, there is nothing to merge. In that case a copy
    of the (optionally gap-closed) input segmentation is used, followed by the optional z-extent filter.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutively across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    # The reshape keeps the (n_edges, 2) layout also for an empty edge list,
    # where np.array([]) would otherwise be one-dimensional.
    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges]).reshape(-1, 2)
    overlaps = np.array([edge["score"] for edge in edges])

    if len(uv_ids) == 0:
        # Without edges there is nothing to merge. Building the graph or indexing
        # `uv_ids` along axis 1 would fail for the empty edge array.
        segmentation = slice_segmentation.copy()
    else:
        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # set background weights to be maximally repulsive
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        # Remove objects whose bounding box spans fewer than `min_z_extent` slices along the first axis.
        props = regionprops(segmentation)
        filter_ids = [prop.label for prop in props if prop.bbox[3] - prop.bbox[0] < min_z_extent]
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation
Merge stacked 2d instance segmentations into a consistent 3d segmentation.
Solves a multicut problem based on the overlap of objects to merge across z.
Arguments:
- slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutively across z.
- beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa.
- with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- verbose: Verbosity flag.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. This enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
Returns:
The merged segmentation.
def
automatic_3d_segmentation( volume: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, **kwargs) -> numpy.ndarray:
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            'min_object_size' is removed from them and used when converting a list of masks
            to a label image.

    Returns:
        The segmentation. If `return_embeddings` is True, a tuple of the segmentation
        and the precomputed image embeddings.
    """
    offset = 0
    segmentation = np.zeros(volume.shape[:3], dtype="uint32")

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        # The segmentor returns either a (possibly empty) list of masks or a label image.
        if isinstance(seg, list):
            if len(seg) == 0:
                continue
            seg = mask_data_to_segmentation(
                seg, with_background=with_background, min_object_size=min_object_size
            )

        # Work on a uint32 copy: the offset must not overflow a smaller input dtype,
        # and the segmentor's output is not modified in place.
        seg = seg.astype("uint32")

        # Set offset for instance per slice.
        max_z = int(seg.max())
        if max_z == 0:
            continue
        seg[seg != 0] += offset
        offset += max_z

        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation
Segment volume in 3d.
First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.
Arguments:
- volume: The input volume.
- predictor: The SAM model.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- with_background: Whether the segmentation has background.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- return_embeddings: Whether to return the precomputed image embeddings.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The segmentation. If `return_embeddings` is True, a tuple of the segmentation and the precomputed image embeddings.