micro_sam.multi_dimensional_segmentation
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

import os
from typing import Callable, Optional, Tuple, Union

import numpy as np

import nifty

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils

from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase, mask_data_to_segmentation

PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")

# Prompt flags (use_box, use_mask, use_points, use_single_point) for each projection mode.
_PROJECTION_FLAGS = {
    "box": (True, False, False, False),
    "mask": (True, True, False, False),
    "points": (False, False, True, False),
    "points_and_mask": (False, True, True, False),
    "single_point": (False, False, True, True),
}


def _validate_projection(projection):
    """Translate a projection method into the prompt flags passed to `segment_from_mask`.

    Args:
        projection: Either one of the names in `PROJECTION_MODES`, or a dictionary with the entries
            'use_box', 'use_mask' and 'use_points' and, optionally, 'use_single_point'.

    Returns:
        The flags use_box, use_mask, use_points and use_single_point.

    Raises:
        ValueError: If the projection method is not supported.
    """
    if isinstance(projection, str):
        if projection not in _PROJECTION_FLAGS:
            raise ValueError(
                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
                f"You have passed the invalid option {projection}."
            )
        return _PROJECTION_FLAGS[projection]

    if isinstance(projection, dict):
        required_keys = ("use_box", "use_mask", "use_points")
        allowed_keys = set(required_keys) | {"use_single_point"}
        # Raise explicitly instead of asserting, so that the check is not stripped with 'python -O'.
        if not all(key in projection for key in required_keys) or not set(projection).issubset(allowed_keys):
            raise ValueError(
                "The projection dictionary must contain the keys 'use_box', 'use_mask' and 'use_points' "
                f"and may contain 'use_single_point'. You have passed {projection}."
            )
        return (
            projection["use_box"], projection["use_mask"], projection["use_points"],
            projection.get("use_single_point", False),
        )

    raise ValueError(f"{projection} is not a supported projection method.")


# Advanced stopping criterions.
# In practice these did not make a big difference, so we do not use this at the moment.
# We still leave it here for reference.
def _advanced_stopping_criteria(
    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
):
    """Compute a stopping criterion for propagating a mask across slices (currently unused).

    Args:
        z: The index of the current slice.
        seg_z: The segmentation of the current slice.
        seg_prev: The segmentation of the previous slice.
        z_start: The index of the slice the propagation started from.
        z_increment: Not used.
        segmentation: The volumetric segmentation with the already segmented slices.
        criterion_choice: Which criterion to compute:
            1: IOU between the current and the previous slice.
            2: Weighted combination of that IOU, the SAM score and the IOU with the start slice.
            3: Mean IOU between the current slice and the previous (up to 5) slices.
        score: The SAM score for the current slice.
        increment: The step between slices (+1 or -1).

    Returns:
        The value of the criterion.

    Raises:
        ValueError: If criterion_choice is not 1, 2 or 3.
    """
    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
        iou_list = [
            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
        ]
        return np.mean(iou_list)

    if criterion_choice == 1:
        # 1. current metric: iou of current segmentation and the previous slice
        iou = util.compute_iou(seg_prev, seg_z)
        criterion = iou

    elif criterion_choice == 2:
        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
        iou = util.compute_iou(seg_prev, seg_z)
        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou

    elif criterion_choice == 3:
        # 3. iou of current segmented slice w.r.t the previous n slices
        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))

    else:
        # Previously this fell through to an UnboundLocalError on 'criterion'.
        raise ValueError(f"Invalid criterion_choice {criterion_choice}, choose one of 1, 2 or 3.")

    return criterion


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[Callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice: the mask of an already segmented slice is used as prompt
    (according to `projection`) to segment the adjacent slice. Below the lowest and above the highest
    segmented slice, propagation continues until the IOU with the previous slice drops below `iou_threshold`
    or the volume boundary is reached (unless `stop_lower` / `stop_upper` are set).
    Gaps between segmented slices are filled by propagating from the neighboring segmented slices.

    Args:
        segmentation: The initial segmentation for the object. Newly segmented slices are written into it.
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The indices of the slices for which this object has already been segmented.
            The order does not matter and duplicate indices are ignored.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar. It is called with 1 for each newly segmented slice.
        box_extension: Extension factor for increasing the box size after projection.
        verbose: Whether to print details about the segmentation steps.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_slice(z, seg_prompt, **kwargs):
        # Segment slice z using seg_prompt (mask from the neighboring slice(s)) as prompt.
        # All calls share the same projection settings, including use_single_point.
        return segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            use_single_point=use_single_point, box_extension=box_extension, **kwargs,
        )

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start towards z_stop in steps of increment. Stops when
        # stopping_criterion(next_z, z_stop) holds or, if threshold is given, when the IOU with the
        # previous slice drops below it. Returns the index of the last segmented slice.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, _, _ = segment_slice(z, seg_prev, return_all=True)
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            # Report every segmented slice, including the last one of the range.
            update_progress(1)
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break

        return z - increment

    # Sort and deduplicate: the gap filling below pairs consecutive entries and relies on
    # strictly increasing slice indices.
    segmented_slices = np.unique(segmented_slices)
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                z = z_start + 1
                # NOTE: the combined prompt assumes the object is labeled with 1 in 'segmentation'.
                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
                segmentation[z] = segment_slice(z, seg_prompt)
                update_progress(1)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
                    segmentation[z_mid] = segment_slice(z_mid, seg_prompt)
                    update_progress(1)

    return segmentation, (z_min, z_max)


def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along z in a stacked 2d segmentation and relabel it with ids unique across slices.

    A binary closing that only acts along z is applied to the foreground. In each slice, the connected
    components of the closed foreground replace the original objects, unless a component overlaps with
    more than one original object; then those original objects are kept, to avoid merging nearby objects.
    Closing does not work for the first and last `gap_closing` slices, so these are only relabeled.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
        gap_closing: The number of iterations for the binary closing.
        pbar_update: Callback that is called with 1 after each processed slice.

    Returns:
        The new segmentation.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def next_offset(seg, offset):
        # Only advance the offset if the slice contains objects. Resetting it for an empty slice
        # would make later slices reuse ids of earlier slices.
        max_id = int(seg.max())
        return max_id + 1 if max_id > 0 else offset

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            return seg_z, next_offset(seg_z, offset)

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = nifty.ground_truth.overlap(closed_z, seg_z)
        matches = {seg_id: matches.overlapArrays(seg_id, sorted=False)[0]
                   for seg_id in range(1, int(closed_z.max() + 1))}
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        return seg_new, next_offset(seg_new, offset)

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation


def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[Callable] = None,
    pbar_update: Optional[Callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutively across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    if len(uv_ids) == 0:
        # No objects overlap across z, so there is nothing to merge. The multicut is skipped, because
        # 'uv_ids' is a 1d array in this case and cannot be used as an (n_edges, 2) edge list below.
        # Copy so that the filtering below never modifies the caller's array.
        segmentation = slice_segmentation.copy()
    else:
        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # set background weights to be maximally repulsive
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            # For 3d data the bbox is (min_z, min_y, min_x, max_z, max_y, max_x).
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation


def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **kwargs,
) -> np.ndarray:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            'min_object_size' is taken out and passed to 'mask_data_to_segmentation' instead.

    Returns:
        The segmentation.
    """
    offset = 0
    segmentation = np.zeros(volume.shape, dtype="uint32")

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)
        if len(seg) == 0:
            continue

        seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
        max_id = seg.max()
        if max_id == 0:
            continue

        # Shift the ids so that they are unique across slices before merging in 3d.
        seg[seg != 0] += offset
        offset += max_id
        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    return segmentation
PROJECTION_MODES =
('box', 'mask', 'points', 'points_and_mask', 'single_point')
def
segment_mask_in_volume( segmentation: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], segmented_slices: numpy.ndarray, stop_lower: bool, stop_upper: bool, iou_threshold: float, projection: Union[str, dict], update_progress: Optional[<built-in function callable>] = None, box_extension: float = 0.0, verbose: bool = False) -> Tuple[numpy.ndarray, Tuple[int, int]]:
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[Callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    The object is propagated slice by slice: the mask of an already segmented slice is used as prompt
    (according to `projection`) to segment the adjacent slice. Below the lowest and above the highest
    segmented slice, propagation continues until the IOU with the previous slice drops below `iou_threshold`
    or the volume boundary is reached (unless `stop_lower` / `stop_upper` are set).
    Gaps between segmented slices are filled by propagating from the neighboring segmented slices.

    Args:
        segmentation: The initial segmentation for the object. Newly segmented slices are written into it.
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The indices of the slices for which this object has already been segmented.
            The order does not matter and duplicate indices are ignored.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar. It is called with 1 for each newly segmented slice.
        box_extension: Extension factor for increasing the box size after projection.
        verbose: Whether to print details about the segmentation steps.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_slice(z, seg_prompt, **kwargs):
        # Segment slice z using seg_prompt (mask from the neighboring slice(s)) as prompt.
        # All calls share the same projection settings, including use_single_point.
        return segment_from_mask(
            predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
            use_mask=use_mask, use_box=use_box, use_points=use_points,
            use_single_point=use_single_point, box_extension=box_extension, **kwargs,
        )

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        # Propagate the mask from z_start towards z_stop in steps of increment. Stops when
        # stopping_criterion(next_z, z_stop) holds or, if threshold is given, when the IOU with the
        # previous slice drops below it. Returns the index of the last segmented slice.
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z, _, _ = segment_slice(z, seg_prev, return_all=True)
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    if verbose:
                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                        print(msg)
                    break

            segmentation[z] = seg_z
            # Report every segmented slice, including the last one of the range.
            update_progress(1)
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break

        return z - increment

    # Sort and deduplicate: the gap filling below pairs consecutive entries and relies on
    # strictly increasing slice indices.
    segmented_slices = np.unique(segmented_slices)
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                z = z_start + 1
                # NOTE: the combined prompt assumes the object is labeled with 1 in 'segmentation'.
                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
                segmentation[z] = segment_slice(z, seg_prompt)
                update_progress(1)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
                    segmentation[z_mid] = segment_slice(z_mid, seg_prompt)
                    update_progress(1)

    return segmentation, (z_min, z_max)
Segment an object mask in volumetric data.
Arguments:
- segmentation: The initial segmentation for the object.
- predictor: The segment anything predictor.
- image_embeddings: The precomputed image embeddings for the volume.
- segmented_slices: List of slices for which this object has already been segmented.
- stop_lower: Whether to stop at the lowest segmented slice.
- stop_upper: Whether to stop at the topmost segmented slice.
- iou_threshold: The IOU threshold for continuing segmentation across 3d.
- projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
- update_progress: Callback to update an external progress bar.
- box_extension: Extension factor for increasing the box size after projection.
- verbose: Whether to print details about the segmentation steps.
Returns:
Array with the volumetric segmentation. Tuple with the first and last segmented slice.
def
merge_instance_segmentation_3d( slice_segmentation: numpy.ndarray, beta: float = 0.5, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> numpy.ndarray:
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[Callable] = None,
    pbar_update: Optional[Callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutively across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    if len(uv_ids) == 0:
        # No objects overlap across z, so there is nothing to merge. The multicut is skipped, because
        # 'uv_ids' is a 1d array in this case and cannot be used as an (n_edges, 2) edge list below.
        # Copy so that the filtering below never modifies the caller's array.
        segmentation = slice_segmentation.copy()
    else:
        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # set background weights to be maximally repulsive
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            # For 3d data the bbox is (min_z, min_y, min_x, max_z, max_y, max_x).
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation
Merge stacked 2d instance segmentations into a consistent 3d segmentation.
Solves a multicut problem based on the overlap of objects to merge across z.
Arguments:
- slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutively across z.
- beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa.
- with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- verbose: Verbosity flag.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. This enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
Returns:
The merged segmentation.
def
automatic_3d_segmentation( volume: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **kwargs) -> numpy.ndarray:
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **kwargs,
) -> np.ndarray:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            'min_object_size' is taken out and passed to 'mask_data_to_segmentation' instead.

    Returns:
        The segmentation.
    """
    offset = 0
    segmentation = np.zeros(volume.shape, dtype="uint32")

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)
        if len(seg) == 0:
            continue

        seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
        max_id = seg.max()
        if max_id == 0:
            continue

        # Shift the ids so that they are unique across slices before merging in 3d.
        seg[seg != 0] += offset
        offset += max_id
        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    return segmentation
Segment volume in 3d.
First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.
Arguments:
- volume: The input volume.
- predictor: The SAM model.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- with_background: Whether the segmentation has background.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The segmentation.