micro_sam.multi_dimensional_segmentation
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

import os
from typing import Optional, Union, Tuple

import numpy as np

import nifty

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils

from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase, mask_data_to_segmentation

PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")


def _validate_projection(projection):
    """Translate a projection specification into the individual prompt flags.

    Args:
        projection: Either one of the names in `PROJECTION_MODES` or a dictionary
            with exactly the keys 'use_box', 'use_mask' and 'use_points'.

    Returns:
        The tuple (use_box, use_mask, use_points, use_single_point).
        'use_single_point' is only set for the 'single_point' mode.

    Raises:
        ValueError: If the projection is an unknown name or neither a string nor a dictionary.
    """
    use_single_point = False
    if isinstance(projection, str):
        if projection == "mask":
            use_box, use_mask, use_points = True, True, False
        elif projection == "points":
            use_box, use_mask, use_points = False, False, True
        elif projection == "box":
            use_box, use_mask, use_points = True, False, False
        elif projection == "points_and_mask":
            use_box, use_mask, use_points = False, True, True
        elif projection == "single_point":
            use_box, use_mask, use_points = False, False, True
            use_single_point = True
        else:
            raise ValueError(
                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
                f"You have passed the invalid option {projection}."
            )
    elif isinstance(projection, dict):
        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
    else:
        raise ValueError(f"{projection} is not a supported projection method.")
    return use_box, use_mask, use_points, use_single_point


# Advanced stopping criteria.
# In practice these did not make a big difference, so we do not use this at the moment.
# We still leave it here for reference.
def _advanced_stopping_criteria(
    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
):
    """Compute an alternative stopping criterion for the slice-wise propagation.

    Args:
        z: The index of the current slice.
        seg_z: The segmentation of the current slice.
        seg_prev: The segmentation of the previously segmented slice.
        z_start: The index of the slice where the propagation started.
        z_increment: Not used by the current implementation.
        segmentation: The volumetric segmentation.
        criterion_choice: Which criterion to compute (1, 2 or 3), see the comments below.
        score: The score returned for the current slice, only used by criterion 2.
        increment: The step in z that is used for the propagation.

    Returns:
        The value of the selected criterion.

    Raises:
        ValueError: If 'criterion_choice' is not one of 1, 2 or 3.
    """
    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
        iou_list = [
            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
        ]
        return np.mean(iou_list)

    if criterion_choice == 1:
        # 1. current metric: iou of current segmentation and the previous slice
        iou = util.compute_iou(seg_prev, seg_z)
        criterion = iou

    elif criterion_choice == 2:
        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
        iou = util.compute_iou(seg_prev, seg_z)
        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou

    elif criterion_choice == 3:
        # 3. iou of current segmented slice w.r.t the previous n slices
        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))

    else:
        # Fail with a clear message instead of an UnboundLocalError on 'criterion'.
        raise ValueError(f"Invalid criterion_choice {criterion_choice}, expected 1, 2 or 3.")

    return criterion


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Segment an object mask in volumetric data.

    Args:
        segmentation: The initial segmentation for the object. It is updated in place.
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: List of slices for which this object has already been segmented.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
        verbose: Whether to print details about the segmentation steps.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            # The previously segmented slice is used as prompt for the current slice.
            seg_prev = segmentation[z - increment]
            seg_z, score, _ = segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
                use_single_point=use_single_point,
            )
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                    print(msg)
                    break

            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            update_progress(1)

        # 'z' points one step past the last slice that was written.
        return z - increment

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
    else:
        z_min = z0

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
    else:
        z_max = z1

    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                z = z_start + 1
                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
                segmentation[z] = segment_from_mask(
                    predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
                    use_mask=use_mask, use_box=use_box, use_points=use_points,
                    box_extension=box_extension
                )
                update_progress(1)

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top bottom
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
                    segmentation[z_mid] = segment_from_mask(
                        predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
                        use_mask=use_mask, use_box=use_box, use_points=use_points,
                        box_extension=box_extension
                    )
                    update_progress(1)

    return segmentation, (z_min, z_max)


def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along z in a stacked 2d segmentation and relabel it slice by slice.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
        gap_closing: The number of iterations for the binary closing.
        pbar_update: Callback to update a progress bar, called once per slice.

    Returns:
        The relabeled segmentation with gaps closed.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            # Only advance the offset if the slice contains objects.
            # Otherwise an empty slice would reset the offset to 1
            # and later slices would reuse ids that were already assigned.
            max_z = int(seg_z.max())
            if max_z > 0:
                offset = max_z + 1
            return seg_z, offset

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = nifty.ground_truth.overlap(closed_z, seg_z)
        matches = {seg_id: matches.overlapArrays(seg_id, sorted=False)[0]
                   for seg_id in range(1, int(closed_z.max() + 1))}
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        max_z = seg_new.max()
        if max_z > 0:
            offset = int(max_z) + 1

        return seg_new, offset

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation


def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using the function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(slice_segmentation.max() + 1)
    graph = nifty.graph.undirectedGraph(n_nodes)
    graph.insertEdges(uv_ids)

    costs = seg_utils.multicut.compute_edge_costs(overlaps)
    # set background weights to be maximally repulsive
    # NOTE(review): the costs are passed as '1.0 - costs' below, which turns -8.0 into 9.0.
    # Confirm against the sign convention of the multicut solver that this is still repulsive.
    if with_background:
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)

    segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            # For a 3d bounding box, entries 0 and 3 are the start and stop along the first axis.
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation


def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **kwargs,
) -> np.ndarray:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            The entry 'min_object_size' is removed from the keyword arguments
            and passed to 'mask_data_to_segmentation' instead.

    Returns:
        The segmentation.
    """
    offset = 0
    segmentation = np.zeros(volume.shape, dtype="uint32")

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)
        if len(seg) == 0:
            continue
        else:
            seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
            max_z = seg.max()
            if max_z == 0:
                continue
            # Shift the ids so that they are unique across the slices.
            seg[seg != 0] += offset
            offset = max_z + offset
        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    return segmentation
PROJECTION_MODES =
('box', 'mask', 'points', 'points_and_mask', 'single_point')
def
segment_mask_in_volume( segmentation: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], segmented_slices: numpy.ndarray, stop_lower: bool, stop_upper: bool, iou_threshold: float, projection: Union[str, dict], update_progress: Optional[<built-in function callable>] = None, box_extension: float = 0.0, verbose: bool = False) -> Tuple[numpy.ndarray, Tuple[int, int]]:
def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: Union[str, dict],
    update_progress: Optional[callable] = None,
    box_extension: float = 0.0,
    verbose: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Propagate the mask of a single object through a volume, slice by slice.

    Starting from the slices that are already annotated, the mask of a neighboring
    slice is used to derive the prompts for the next slice. The result is written
    into 'segmentation' in place.

    Args:
        segmentation: The initial segmentation for the object.
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings for the volume.
        segmented_slices: The slices for which this object has already been segmented.
        stop_lower: Whether to stop at the lowest segmented slice.
        stop_upper: Whether to stop at the topmost segmented slice.
        iou_threshold: The IOU threshold for continuing segmentation across 3d.
            It is only applied when extending below / above the annotated slices.
        projection: The projection method to use. One of 'box', 'mask', 'points',
            'points_and_mask' or 'single_point'.
            Pass a dictionary to choose the exact combination of projection modes.
        update_progress: Callback to update an external progress bar.
        box_extension: Extension factor for increasing the box size after projection.
        verbose: Whether to print details about the segmentation steps.

    Returns:
        Array with the volumetric segmentation.
        Tuple with the first and last segmented slice.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    # Fall back to a no-op so that progress can be reported unconditionally.
    report_progress = update_progress if update_progress is not None else (lambda *args: None)

    # Prompt settings that are shared by all calls to 'segment_from_mask'.
    prompt_kwargs = dict(
        image_embeddings=image_embeddings,
        use_mask=use_mask,
        use_box=use_box,
        use_points=use_points,
        box_extension=box_extension,
    )

    def propagate(origin, target, step, is_done, min_iou=None):
        """Segment slice after slice from 'origin' towards 'target'.

        Returns the index of the last slice that was written.
        """
        current = origin + step
        while True:
            if verbose:
                print(f"Segment {origin} to {target}: segmenting slice {current}")
            previous_mask = segmentation[current - step]
            predicted, _score, _ = segment_from_mask(
                predictor, previous_mask, i=current, return_all=True,
                use_single_point=use_single_point, **prompt_kwargs,
            )
            if min_iou is not None:
                overlap = util.compute_iou(previous_mask, predicted)
                if overlap < min_iou:
                    print(f"Segmentation stopped at slice {current} due to IOU {overlap} < {min_iou}.")
                    break

            segmentation[current] = predicted
            current += step
            if is_done(current, target):
                if verbose:
                    print(f"Segment {origin} to {target}: stop at slice {current}")
                break
            report_progress(1)

        return current - step

    def fill_from_neighbors(z):
        """Segment slice 'z' with the union of the masks directly below and above as prompt."""
        combined_prompt = np.logical_or(segmentation[z - 1] == 1, segmentation[z + 1] == 1)
        segmentation[z] = segment_from_mask(predictor, combined_prompt, i=z, **prompt_kwargs)
        report_progress(1)

    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
    last_slice = segmentation.shape[0] - 1

    # Extend below the lowest annotated slice.
    z_min = z0
    if z0 > 0 and not stop_lower:
        z_min = propagate(z0, 0, -1, np.less, iou_threshold)

    # Extend above the topmost annotated slice.
    z_max = z1
    if z1 < last_slice and not stop_upper:
        z_max = propagate(z1, last_slice, 1, np.greater, iou_threshold)

    # With a single annotated slice there is nothing to fill in between.
    if z0 == z1:
        return segmentation, (z_min, z_max)

    # Fill the gaps between consecutive annotated slices.
    for lower, upper in zip(segmented_slices[:-1], segmented_slices[1:]):
        gap = upper - lower

        # Adjacent slices: nothing to do.
        if gap == 1:
            continue

        # The lower slice is a stop slice: only propagate downwards from the upper slice.
        if lower == z0 and stop_lower:
            propagate(upper, lower, -1, np.less_equal)
            continue

        # The upper slice is a stop slice: only propagate upwards from the lower slice.
        if upper == z1 and stop_upper:
            propagate(lower, upper, 1, np.greater_equal)
            continue

        # Exactly one slice in between: prompt it with both neighbors.
        if gap == 2:
            fill_from_neighbors(lower + 1)
            continue

        # Larger gap: propagate from both ends towards the middle.
        center = int((lower + upper) // 2)
        has_center_slice = gap % 2 == 0
        propagate(lower, center, 1, np.greater_equal if has_center_slice else np.greater)
        propagate(upper, center, -1, np.less_equal)

        # For an even gap the middle slice is not reached by either pass,
        # so it is segmented with the combined mask of its two neighbors.
        if has_center_slice:
            fill_from_neighbors(center)

    return segmentation, (z_min, z_max)
Segment an object mask in volumetric data.
Arguments:
- segmentation: The initial segmentation for the object.
- predictor: The segment anything predictor.
- image_embeddings: The precomputed image embeddings for the volume.
- segmented_slices: List of slices for which this object has already been segmented.
- stop_lower: Whether to stop at the lowest segmented slice.
- stop_upper: Whether to stop at the topmost segmented slice.
- iou_threshold: The IOU threshold for continuing segmentation across 3d.
- projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
- update_progress: Callback to update an external progress bar.
- box_extension: Extension factor for increasing the box size after projection.
- verbose: Whether to print details about the segmentation steps.
Returns:
Array with the volumetric segmentation. Tuple with the first and last segmented slice.
def
merge_instance_segmentation_3d( slice_segmentation: numpy.ndarray, beta: float = 0.5, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> numpy.ndarray:
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    The objects of neighboring slices are connected by edges that are scored by
    their overlap. A multicut problem is solved on the resulting graph to decide
    which objects are merged across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using the function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    # The gap closing reports one step per slice, the merging itself one more step.
    close_gaps = gap_closing is not None and gap_closing > 0
    n_steps = slice_segmentation.shape[0] + 1 if close_gaps else 1
    pbar_init(n_steps, "Merge segmentation")
    if close_gaps:
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)

    # Build the graph: one node per object id, one edge per overlap between slices.
    overlap_edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
    uv_ids = np.array([[e["source"], e["target"]] for e in overlap_edges])
    overlaps = np.array([e["score"] for e in overlap_edges])

    graph = nifty.graph.undirectedGraph(int(slice_segmentation.max() + 1))
    graph.insertEdges(uv_ids)

    costs = seg_utils.multicut.compute_edge_costs(overlaps)
    if with_background:
        # Set the weights of all edges that touch the background (id 0) to be maximally repulsive.
        # NOTE(review): the costs are passed as '1.0 - costs' below, which turns -8.0 into 9.0.
        # Confirm against the sign convention of the multicut solver that this is still repulsive.
        touches_background = np.any(uv_ids == 0, axis=1)
        costs[touches_background] = -8.0

    node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)

    # Map the multicut result back to the volume.
    segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        # For a 3d bounding box, entries 0 and 3 are the start and stop along the first axis.
        too_short = [
            prop.label for prop in regionprops(segmentation)
            if prop.bbox[3] - prop.bbox[0] < min_z_extent
        ]
        if too_short:
            segmentation[np.isin(segmentation, too_short)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation
Merge stacked 2d instance segmentations into a consistent 3d segmentation.
Solves a multicut problem based on the overlap of objects to merge across z.
Arguments:
- slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
- beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa.
- with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- verbose: Verbosity flag.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. This enables using the function within a thread worker.
- pbar_update: Callback to update an external progress bar.
Returns:
The merged segmentation.
def
automatic_3d_segmentation( volume: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **kwargs) -> numpy.ndarray:
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **kwargs,
) -> np.ndarray:
    """Segment a volume in 3d.

    Each slice is segmented individually in 2d. The per-slice results are then
    merged across z based on the overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            The entry 'min_object_size' is removed from the keyword arguments
            and passed to 'mask_data_to_segmentation' instead.

    Returns:
        The segmentation.
    """
    # 'min_object_size' is consumed here, all other keyword arguments go to 'generate'.
    min_object_size = kwargs.pop("min_object_size", 0)

    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    segmentation = np.zeros(volume.shape, dtype="uint32")
    n_slices = segmentation.shape[0]

    # The largest id that has been assigned so far, used to keep ids unique across slices.
    id_offset = 0
    for i in tqdm(range(n_slices), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        masks = segmentor.generate(**kwargs)
        # Nothing was found in this slice, it stays empty.
        if len(masks) == 0:
            continue

        slice_seg = mask_data_to_segmentation(
            masks, with_background=with_background, min_object_size=min_object_size
        )
        slice_max = slice_seg.max()
        # No object is left in the converted segmentation, the slice stays empty.
        if slice_max == 0:
            continue

        slice_seg[slice_seg != 0] += id_offset
        id_offset = slice_max + id_offset
        segmentation[i] = slice_seg

    return merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )
Segment volume in 3d.
First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.
Arguments:
- volume: The input volume.
- predictor: The SAM model.
- segmentor: The instance segmentation class.
- embedding_path: The path to save pre-computed embeddings.
- with_background: Whether the segmentation has background.
- gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
- min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:
The segmentation.