micro_sam.multi_dimensional_segmentation

Multi-dimensional segmentation with segment anything.

  1"""Multi-dimensional segmentation with segment anything.
  2"""
  3
import os
from typing import Callable, Optional, Union, Tuple

import numpy as np
from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

import nifty

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils

from segment_anything.predictor import SamPredictor

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase, mask_data_to_segmentation
 27
 28
 29PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")
 30
 31
 32def _validate_projection(projection):
 33    use_single_point = False
 34    if isinstance(projection, str):
 35        if projection == "mask":
 36            use_box, use_mask, use_points = True, True, False
 37        elif projection == "points":
 38            use_box, use_mask, use_points = False, False, True
 39        elif projection == "box":
 40            use_box, use_mask, use_points = True, False, False
 41        elif projection == "points_and_mask":
 42            use_box, use_mask, use_points = False, True, True
 43        elif projection == "single_point":
 44            use_box, use_mask, use_points = False, False, True
 45            use_single_point = True
 46        else:
 47            raise ValueError(
 48                "Choose projection method from 'mask' / 'points' / 'box' / 'points_and_mask' / 'single_point'. "
 49                f"You have passed the invalid option {projection}."
 50            )
 51    elif isinstance(projection, dict):
 52        assert len(projection.keys()) == 3, "There should be three parameters assigned for the projection method."
 53        use_box, use_mask, use_points = projection["use_box"], projection["use_mask"], projection["use_points"]
 54    else:
 55        raise ValueError(f"{projection} is not a supported projection method.")
 56    return use_box, use_mask, use_points, use_single_point
 57
 58
 59# Advanced stopping criterions.
 60# In practice these did not make a big difference, so we do not use this at the moment.
 61# We still leave it here for reference.
 62def _advanced_stopping_criteria(
 63    z, seg_z, seg_prev, z_start, z_increment, segmentation, criterion_choice, score, increment
 64):
 65    def _compute_mean_iou_for_n_slices(z, increment, seg_z, n_slices):
 66        iou_list = [
 67            util.compute_iou(segmentation[z - increment * _slice], seg_z) for _slice in range(1, n_slices+1)
 68        ]
 69        return np.mean(iou_list)
 70
 71    if criterion_choice == 1:
 72        # 1. current metric: iou of current segmentation and the previous slice
 73        iou = util.compute_iou(seg_prev, seg_z)
 74        criterion = iou
 75
 76    elif criterion_choice == 2:
 77        # 2. combining SAM iou + iou: curr. slice & first segmented slice + iou: curr. slice vs prev. slice
 78        iou = util.compute_iou(seg_prev, seg_z)
 79        ff_iou = util.compute_iou(segmentation[z_start], seg_z)
 80        criterion = 0.5 * iou + 0.3 * score + 0.2 * ff_iou
 81
 82    elif criterion_choice == 3:
 83        # 3. iou of current segmented slice w.r.t the previous n slices
 84        criterion = _compute_mean_iou_for_n_slices(z, increment, seg_z, min(5, abs(z - z_start)))
 85
 86    return criterion
 87
 88
 89def segment_mask_in_volume(
 90    segmentation: np.ndarray,
 91    predictor: SamPredictor,
 92    image_embeddings: util.ImageEmbeddings,
 93    segmented_slices: np.ndarray,
 94    stop_lower: bool,
 95    stop_upper: bool,
 96    iou_threshold: float,
 97    projection: Union[str, dict],
 98    update_progress: Optional[callable] = None,
 99    box_extension: float = 0.0,
100    verbose: bool = False,
101) -> Tuple[np.ndarray, Tuple[int, int]]:
102    """Segment an object mask in in volumetric data.
103
104    Args:
105        segmentation: The initial segmentation for the object.
106        predictor: The segment anything predictor.
107        image_embeddings: The precomputed image embeddings for the volume.
108        segmented_slices: List of slices for which this object has already been segmented.
109        stop_lower: Whether to stop at the lowest segmented slice.
110        stop_upper: Wheter to stop at the topmost segmented slice.
111        iou_threshold: The IOU threshold for continuing segmentation across 3d.
112        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single point'.
113            Pass a dictionary to choose the excact combination of projection modes.
114        update_progress: Callback to update an external progress bar.
115        box_extension: Extension factor for increasing the box size after projection.
116        verbose: Whether to print details about the segmentation steps.
117
118    Returns:
119        Array with the volumetric segmentation.
120        Tuple with the first and last segmented slice.
121    """
122    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)
123
124    if update_progress is None:
125        def update_progress(*args):
126            pass
127
128    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
129        z = z_start + increment
130        while True:
131            if verbose:
132                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
133            seg_prev = segmentation[z - increment]
134            seg_z, score, _ = segment_from_mask(
135                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
136                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
137                use_single_point=use_single_point,
138            )
139            if threshold is not None:
140                iou = util.compute_iou(seg_prev, seg_z)
141                if iou < threshold:
142                    if verbose:
143                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
144                        print(msg)
145                    break
146
147            segmentation[z] = seg_z
148            z += increment
149            if stopping_criterion(z, z_stop):
150                if verbose:
151                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
152                break
153            update_progress(1)
154
155        return z - increment
156
157    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
158
159    # segment below the min slice
160    if z0 > 0 and not stop_lower:
161        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
162    else:
163        z_min = z0
164
165    # segment above the max slice
166    if z1 < segmentation.shape[0] - 1 and not stop_upper:
167        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
168    else:
169        z_max = z1
170
171    # segment in between min and max slice
172    if z0 != z1:
173        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
174            slice_diff = z_stop - z_start
175            z_mid = int((z_start + z_stop) // 2)
176
177            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
178                pass
179
180            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
181                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)
182
183            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
184                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)
185
186            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
187                z = z_start + 1
188                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
189                segmentation[z] = segment_from_mask(
190                    predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
191                    use_mask=use_mask, use_box=use_box, use_points=use_points,
192                    box_extension=box_extension
193                )
194                update_progress(1)
195
196            else:  # there is a range of more than 2 slices in between -> segment ranges
197                # segment from bottom
198                segment_range(
199                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
200                )
201                # segment from top
202                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
203                # if the difference between start and stop is even,
204                # then we have a slice in the middle that is the same distance from top bottom
205                # in this case the slice is not segmented in the ranges above, and we segment it
206                # using the combined mask from the adjacent top and bottom slice as prompt
207                if slice_diff % 2 == 0:
208                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
209                    segmentation[z_mid] = segment_from_mask(
210                        predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
211                        use_mask=use_mask, use_box=use_box, use_points=use_points,
212                        box_extension=box_extension
213                    )
214                    update_progress(1)
215
216    return segmentation, (z_min, z_max)
217
218
def _preprocess_closing(slice_segmentation, gap_closing, pbar_update):
    """Close gaps along z in a stack of 2d instance segmentations.

    The foreground is closed with a binary closing whose structuring element only extends
    along the first (z) axis. In each slice, the connected components of the closed foreground
    replace the initial objects, unless a component overlaps more than one initial object;
    in that case the initial objects are kept, to avoid merging nearby objects. The first and
    last ``gap_closing`` slices are only relabeled.

    Args:
        slice_segmentation: The stacked 2d segmentations, with z as the first axis.
        gap_closing: The number of iterations for the binary closing.
        pbar_update: Callback that is called with 1 after each processed slice.

    Returns:
        The gap-closed segmentation. Object ids are unique across slices.
    """
    binarized = slice_segmentation > 0
    # Use a structuring element that only closes elements in z, to avoid merging objects in-plane.
    structuring_element = np.zeros((3, 1, 1))
    structuring_element[:, 0, 0] = 1
    closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element)

    new_segmentation = np.zeros_like(slice_segmentation)
    n_slices = new_segmentation.shape[0]

    def _next_offset(seg, offset):
        # Move the offset past the largest id in `seg`. Empty slices keep the current offset,
        # otherwise the ids of later slices would restart at 1 and collide with earlier slices.
        max_id = int(seg.max())
        return max_id + 1 if max_id > 0 else offset

    def process_slice(z, offset):
        seg_z = slice_segmentation[z]

        # Closing does not work for the first and last gap slices
        if z < gap_closing or z >= (n_slices - gap_closing):
            seg_z, _, _ = relabel_sequential(seg_z, offset=offset)
            return seg_z, _next_offset(seg_z, offset)

        # Apply connected components to the closed segmentation.
        closed_z = label(closed_segmentation[z])

        # Map objects in the closed and initial segmentation.
        # We take objects from the closed segmentation unless they
        # have overlap with more than one object from the initial segmentation.
        # This indicates wrong merging of closeby objects that we want to prevent.
        matches = nifty.ground_truth.overlap(closed_z, seg_z)
        matches = {
            seg_id: matches.overlapArrays(seg_id, sorted=False)[0] for seg_id in range(1, int(closed_z.max() + 1))
        }
        matches = {k: v[v != 0] for k, v in matches.items()}

        ids_initial, ids_closed = [], []
        for seg_id, matched in matches.items():
            if len(matched) > 1:
                ids_initial.extend(matched.tolist())
            else:
                ids_closed.append(seg_id)

        seg_new = np.zeros_like(seg_z)
        closed_mask = np.isin(closed_z, ids_closed)
        seg_new[closed_mask] = closed_z[closed_mask]

        if ids_initial:
            initial_mask = np.isin(seg_z, ids_initial)
            seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=int(seg_new.max()) + 1)[0]

        seg_new, _, _ = relabel_sequential(seg_new, offset=offset)
        return seg_new, _next_offset(seg_new, offset)

    # Further optimization: parallelize
    offset = 1
    for z in range(n_slices):
        new_segmentation[z], offset = process_slice(z, offset)
        pbar_update(1)

    return new_segmentation
280
281
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[Callable] = None,
    pbar_update: Optional[Callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation. If no objects overlap between slices, the (optionally gap-closed)
        input segmentation is returned unmerged.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    if uv_ids.size == 0:
        # No edges (e.g. a single slice or an empty volume): there is nothing to merge.
        # np.array([]) is 1d, so the (n, 2) edge handling below would fail; skip it.
        segmentation = slice_segmentation.copy()
    else:
        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # set background weights to be maximally repulsive
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            # box[3] - box[0] is the extent along the first (z) axis.
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation
359
360
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, util.ImageEmbeddings]]:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            'min_object_size' is removed from them and used when converting list outputs of
            'generate' into a segmentation.

    Returns:
        The segmentation. If `return_embeddings` is True, a tuple of the segmentation
        and the image embeddings.
    """
    offset = 0
    segmentation = np.zeros(volume.shape[:3], dtype="uint32")

    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        if isinstance(seg, list):
            if len(seg) == 0:
                continue
            seg = mask_data_to_segmentation(
                seg, with_background=with_background, min_object_size=min_object_size
            )

        # Add the offset in the dtype of the output volume, so that it cannot overflow a
        # smaller integer dtype (or misbehave for a boolean mask) returned by the segmentor.
        seg = np.asarray(seg).astype(segmentation.dtype, copy=False)

        # Set offset for instance per slice.
        max_id = int(seg.max())
        if max_id == 0:
            continue
        seg[seg != 0] += offset
        offset += max_id

        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    if return_embeddings:
        return segmentation, image_embeddings
    else:
        return segmentation
# Names of the projection modes that can be passed as the ``projection`` argument
# of ``segment_mask_in_volume`` (a dictionary of flags is accepted as well).
PROJECTION_MODES = (
    "box",
    "mask",
    "points",
    "points_and_mask",
    "single_point",
)
def segment_mask_in_volume( segmentation: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], segmented_slices: numpy.ndarray, stop_lower: bool, stop_upper: bool, iou_threshold: float, projection: Union[str, dict], update_progress: Optional[<built-in function callable>] = None, box_extension: float = 0.0, verbose: bool = False) -> Tuple[numpy.ndarray, Tuple[int, int]]:
 90def segment_mask_in_volume(
 91    segmentation: np.ndarray,
 92    predictor: SamPredictor,
 93    image_embeddings: util.ImageEmbeddings,
 94    segmented_slices: np.ndarray,
 95    stop_lower: bool,
 96    stop_upper: bool,
 97    iou_threshold: float,
 98    projection: Union[str, dict],
 99    update_progress: Optional[callable] = None,
100    box_extension: float = 0.0,
101    verbose: bool = False,
102) -> Tuple[np.ndarray, Tuple[int, int]]:
103    """Segment an object mask in in volumetric data.
104
105    Args:
106        segmentation: The initial segmentation for the object.
107        predictor: The segment anything predictor.
108        image_embeddings: The precomputed image embeddings for the volume.
109        segmented_slices: List of slices for which this object has already been segmented.
110        stop_lower: Whether to stop at the lowest segmented slice.
111        stop_upper: Wheter to stop at the topmost segmented slice.
112        iou_threshold: The IOU threshold for continuing segmentation across 3d.
113        projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single point'.
114            Pass a dictionary to choose the excact combination of projection modes.
115        update_progress: Callback to update an external progress bar.
116        box_extension: Extension factor for increasing the box size after projection.
117        verbose: Whether to print details about the segmentation steps.
118
119    Returns:
120        Array with the volumetric segmentation.
121        Tuple with the first and last segmented slice.
122    """
123    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)
124
125    if update_progress is None:
126        def update_progress(*args):
127            pass
128
129    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
130        z = z_start + increment
131        while True:
132            if verbose:
133                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
134            seg_prev = segmentation[z - increment]
135            seg_z, score, _ = segment_from_mask(
136                predictor, seg_prev, image_embeddings=image_embeddings, i=z, use_mask=use_mask,
137                use_box=use_box, use_points=use_points, box_extension=box_extension, return_all=True,
138                use_single_point=use_single_point,
139            )
140            if threshold is not None:
141                iou = util.compute_iou(seg_prev, seg_z)
142                if iou < threshold:
143                    if verbose:
144                        msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
145                        print(msg)
146                    break
147
148            segmentation[z] = seg_z
149            z += increment
150            if stopping_criterion(z, z_stop):
151                if verbose:
152                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
153                break
154            update_progress(1)
155
156        return z - increment
157
158    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
159
160    # segment below the min slice
161    if z0 > 0 and not stop_lower:
162        z_min = segment_range(z0, 0, -1, np.less, iou_threshold, verbose=verbose)
163    else:
164        z_min = z0
165
166    # segment above the max slice
167    if z1 < segmentation.shape[0] - 1 and not stop_upper:
168        z_max = segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold, verbose=verbose)
169    else:
170        z_max = z1
171
172    # segment in between min and max slice
173    if z0 != z1:
174        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
175            slice_diff = z_stop - z_start
176            z_mid = int((z_start + z_stop) // 2)
177
178            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
179                pass
180
181            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
182                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)
183
184            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
185                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)
186
187            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
188                z = z_start + 1
189                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
190                segmentation[z] = segment_from_mask(
191                    predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
192                    use_mask=use_mask, use_box=use_box, use_points=use_points,
193                    box_extension=box_extension
194                )
195                update_progress(1)
196
197            else:  # there is a range of more than 2 slices in between -> segment ranges
198                # segment from bottom
199                segment_range(
200                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
201                )
202                # segment from top
203                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
204                # if the difference between start and stop is even,
205                # then we have a slice in the middle that is the same distance from top bottom
206                # in this case the slice is not segmented in the ranges above, and we segment it
207                # using the combined mask from the adjacent top and bottom slice as prompt
208                if slice_diff % 2 == 0:
209                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
210                    segmentation[z_mid] = segment_from_mask(
211                        predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
212                        use_mask=use_mask, use_box=use_box, use_points=use_points,
213                        box_extension=box_extension
214                    )
215                    update_progress(1)
216
217    return segmentation, (z_min, z_max)

Segment an object mask in volumetric data.

Arguments:
  • segmentation: The initial segmentation for the object.
  • predictor: The segment anything predictor.
  • image_embeddings: The precomputed image embeddings for the volume.
  • segmented_slices: List of slices for which this object has already been segmented.
  • stop_lower: Whether to stop at the lowest segmented slice.
  • stop_upper: Whether to stop at the topmost segmented slice.
  • iou_threshold: The IOU threshold for continuing segmentation across 3d.
  • projection: The projection method to use. One of 'box', 'mask', 'points', 'points_and_mask' or 'single_point'. Pass a dictionary to choose the exact combination of projection modes.
  • update_progress: Callback to update an external progress bar.
  • box_extension: Extension factor for increasing the box size after projection.
  • verbose: Whether to print details about the segmentation steps.
Returns:

Array with the volumetric segmentation. Tuple with the first and last segmented slice.

def merge_instance_segmentation_3d( slice_segmentation: numpy.ndarray, beta: float = 0.5, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, verbose: bool = True, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> numpy.ndarray:
def merge_instance_segmentation_3d(
    slice_segmentation: np.ndarray,
    beta: float = 0.5,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    verbose: bool = True,
    pbar_init: Optional[Callable] = None,
    pbar_update: Optional[Callable] = None,
) -> np.ndarray:
    """Merge stacked 2d instance segmentations into a consistent 3d segmentation.

    Solves a multicut problem based on the overlap of objects to merge across z.

    Args:
        slice_segmentation: The stacked segmentation across the slices.
            We assume that the segmentation is labeled consecutive across z.
        beta: The bias term for the multicut. Higher values lead to a larger
            degree of over-segmentation and vice versa.
        with_background: Whether this is a segmentation problem with background.
            In that case all edges connecting to the background are set to be repulsive.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        verbose: Verbosity flag.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The merged segmentation. If no objects overlap between slices, the (optionally gap-closed)
        input segmentation is returned unmerged.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    if gap_closing is not None and gap_closing > 0:
        pbar_init(slice_segmentation.shape[0] + 1, "Merge segmentation")
        slice_segmentation = _preprocess_closing(slice_segmentation, gap_closing, pbar_update)
    else:
        pbar_init(1, "Merge segmentation")

    # Extract the overlap between slices.
    edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    if uv_ids.size == 0:
        # No edges (e.g. a single slice or an empty volume): there is nothing to merge.
        # np.array([]) is 1d, so the (n, 2) edge handling below would fail; skip it.
        segmentation = slice_segmentation.copy()
    else:
        n_nodes = int(slice_segmentation.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = seg_utils.multicut.compute_edge_costs(overlaps)
        # set background weights to be maximally repulsive
        if with_background:
            bg_edges = (uv_ids == 0).any(axis=1)
            costs[bg_edges] = -8.0

        node_labels = seg_utils.multicut.multicut_decomposition(graph, 1.0 - costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, slice_segmentation)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            # box[3] - box[0] is the extent along the first (z) axis.
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    pbar_update(1)
    pbar_close()

    return segmentation

Merge stacked 2d instance segmentations into a consistent 3d segmentation.

Solves a multicut problem based on the overlap of objects to merge across z.

Arguments:
  • slice_segmentation: The stacked segmentation across the slices. We assume that the segmentation is labeled consecutive across z.
  • beta: The bias term for the multicut. Higher values lead to a larger degree of over-segmentation and vice versa.
  • with_background: Whether this is a segmentation problem with background. In that case all edges connecting to the background are set to be repulsive.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
  • verbose: Verbosity flag.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. This enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The merged segmentation.

def automatic_3d_segmentation( volume: numpy.ndarray, predictor: segment_anything.predictor.SamPredictor, segmentor: micro_sam.instance_segmentation.AMGBase, embedding_path: Union[str, os.PathLike, NoneType] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, **kwargs) -> numpy.ndarray:
def automatic_3d_segmentation(
    volume: np.ndarray,
    predictor: SamPredictor,
    segmentor: AMGBase,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    with_background: bool = True,
    gap_closing: Optional[int] = None,
    min_z_extent: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
    """Segment volume in 3d.

    First segments slices individually in 2d and then merges them across 3d
    based on overlap of objects between slices.

    Args:
        volume: The input volume.
        predictor: The SAM model.
        segmentor: The instance segmentation class.
        embedding_path: The path to save pre-computed embeddings.
        with_background: Whether the segmentation has background.
        gap_closing: If given, gaps in the segmentation are closed with a binary closing
            operation. The value is used to determine the number of iterations for the closing.
        min_z_extent: Require a minimal extent in z for the segmented objects.
            This can help to prevent segmentation artifacts.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
            The key 'min_object_size' (default 0) is removed from these and used when
            converting mask data to a label image instead of being passed to 'generate'.

    Returns:
        The segmentation. If `return_embeddings` is True, a tuple of the segmentation
        and the precomputed image embeddings.
    """
    # Running id offset so that labels from different slices never collide.
    offset = 0
    segmentation = np.zeros(volume.shape[:3], dtype="uint32")

    # Consumed here (for 'mask_data_to_segmentation') rather than forwarded to 'generate'.
    min_object_size = kwargs.pop("min_object_size", 0)
    image_embeddings = util.precompute_image_embeddings(
        predictor=predictor,
        input_=volume,
        save_path=embedding_path,
        ndim=3,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
        seg = segmentor.generate(**kwargs)

        if isinstance(seg, list):
            # An empty list of masks: nothing was found in this slice.
            if len(seg) == 0:
                continue
            seg = mask_data_to_segmentation(
                seg, with_background=with_background, min_object_size=min_object_size
            )

        # Work on a copy in the output dtype: this avoids mutating the segmentor's
        # result in place and prevents the label shift below from overflowing when
        # the per-slice labels come in a narrower integer dtype.
        seg = np.asarray(seg).astype(segmentation.dtype)

        # Keep the offset arithmetic in Python ints.
        max_id = int(seg.max())
        if max_id == 0:
            continue

        # Set offset for instances per slice; background (0) stays untouched.
        seg[seg != 0] += offset
        offset += max_id

        segmentation[i] = seg

    segmentation = merge_instance_segmentation_3d(
        segmentation,
        beta=0.5,
        with_background=with_background,
        gap_closing=gap_closing,
        min_z_extent=min_z_extent,
        verbose=verbose,
    )

    if return_embeddings:
        return segmentation, image_embeddings
    return segmentation

Segment volume in 3d.

First segments slices individually in 2d and then merges them across 3d based on overlap of objects between slices.

Arguments:
  • volume: The input volume.
  • predictor: The SAM model.
  • segmentor: The instance segmentation class.
  • embedding_path: The path to save pre-computed embeddings.
  • with_background: Whether the segmentation has background.
  • gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing.
  • min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Verbosity flag.
  • return_embeddings: Whether to return the precomputed image embeddings.
  • kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
Returns:

The segmentation. If return_embeddings is True, a tuple of the segmentation and the precomputed image embeddings.