
import os
import multiprocessing as mp
from concurrent import futures
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from scipy.ndimage import distance_transform_edt, binary_dilation
from sklearn.metrics import pairwise_distances
from skimage.measure import regionprops
from skimage.draw import line_nd
from tqdm import tqdm

try:
    import skfmm
except ImportError:
    skfmm = None
 21def compute_geodesic_distances(
 22    segmentation: np.ndarray,
 23    distance_to: np.ndarray,
 24    resolution: Optional[Union[int, float, Tuple[int, int, int]]] = None,
 25    unsigned: bool = True,
 26) -> np.ndarray:
 27    """Compute the geodesic distances between a segmentation and a distance target.
 29    This function require scikit-fmm to be installed.
 31    Args:
 32        segmentation: The binary segmentation.
 33        distance_to: The binary distance target.
 34        resolution: The voxel size of the data, used to scale the distances.
 35        unsigned: Whether to return the unsigned or signed distances.
 37    Returns:
 38        Array with the geodesic distance values.
 39    """
 40    assert skfmm is not None, "Please install scikit-fmm to use compute_geodesic_distance."
 42    invalid = segmentation == 0
 43    input_ =, mask=invalid)
 44    input_[distance_to] = 0
 46    if resolution is None:
 47        dx = 1.0
 48    elif isinstance(resolution, (int, float)):
 49        dx = float(resolution)
 50    else:
 51        assert len(resolution) == segmentation.ndim
 52        dx = resolution
 54    distances = skfmm.distance(input_, dx=dx).data
 55    distances[distances == 0] = np.inf
 56    distances[distance_to] = 0
 58    if unsigned:
 59        distances = np.abs(distances)
 61    return distances
 64def _compute_centroid_distances(segmentation, resolution, n_neighbors):
 65    props = regionprops(segmentation)
 66    centroids = np.array([prop.centroid for prop in props])
 67    if resolution is not None:
 68        scale_factor = np.array(resolution)[:, None]
 69        centroids *= scale_factor
 70    pair_distances = pairwise_distances(centroids)
 71    return pair_distances
 74def _compute_boundary_distances(segmentation, resolution, n_threads):
 76    seg_ids = np.unique(segmentation)[1:]
 77    n = len(seg_ids)
 79    pairwise_distances = np.zeros((n, n))
 80    ndim = segmentation.ndim
 81    end_points1 = np.zeros((n, n, ndim), dtype="int")
 82    end_points2 = np.zeros((n, n, ndim), dtype="int")
 84    properties = regionprops(segmentation)
 85    properties = {prop.label: prop for prop in properties}
 87    def compute_distances_for_object(i):
 89        seg_id = seg_ids[i]
 90        distances, indices = distance_transform_edt(segmentation != seg_id, return_indices=True, sampling=resolution)
 92        for j in range(len(seg_ids)):
 93            if i >= j:
 94                continue
 96            ngb_id = seg_ids[j]
 97            prop = properties[ngb_id]
 99            bb = prop.bbox
100            offset = np.array(bb[:ndim])
101            if ndim == 2:
102                bb = np.s_[bb[0]:bb[2], bb[1]:bb[3]]
103            else:
104                bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
106            mask = segmentation[bb] == ngb_id
107            ngb_dist, ngb_index = distances[bb].copy(), indices[(slice(None),) + bb]
108            ngb_dist[~mask] = np.inf
109            min_point_ngb = np.unravel_index(np.argmin(ngb_dist), shape=mask.shape)
111            min_dist = ngb_dist[min_point_ngb]
113            min_point = tuple(ind[min_point_ngb] for ind in ngb_index)
114            pairwise_distances[i, j] = min_dist
116            end_points1[i, j] = min_point
117            min_point_ngb = [off + minp for off, minp in zip(offset, min_point_ngb)]
118            end_points2[i, j] = min_point_ngb
120    n_threads = mp.cpu_count() if n_threads is None else n_threads
121    with futures.ThreadPoolExecutor(n_threads) as tp:
122        list(tqdm(
123  , range(n)), total=n, desc="Compute boundary distances"
124        ))
126    return pairwise_distances, end_points1, end_points2, seg_ids
def measure_pairwise_object_distances(
    segmentation: np.ndarray,
    distance_type: str = "boundary",
    resolution: Optional[Tuple[int, int, int]] = None,
    n_threads: Optional[int] = None,
    save_path: Optional[os.PathLike] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Measure the distance between every pair of objects in a segmentation.

    Args:
        segmentation: The input segmentation.
        distance_type: Which distance to measure: 'boundary' for the distance between the
            boundaries / surfaces of the objects, 'centroid' for the distance between their
            centroids. Only 'boundary' is implemented at the moment.
        resolution: The resolution / pixel size of the data.
        n_threads: The number of threads used to parallelize the distance computation.
        save_path: Optional path for storing the results as a zipped numpy file (`np.savez`).

    Returns:
        The pairwise object distances.
        The 'left' endpoint coordinates of the distances.
        The 'right' endpoint coordinates of the distances.
        The segmentation ids corresponding to the rows / columns of the distances.

    Raises:
        AssertionError: If `distance_type` is neither 'boundary' nor 'centroid'.
        NotImplementedError: If `distance_type` is 'centroid'.
    """
    assert distance_type in ("boundary", "centroid")

    if distance_type == "centroid":
        # TODO: _compute_centroid_distances has to be adapted to return end points and ids.
        raise NotImplementedError
    elif distance_type == "boundary":
        result = _compute_boundary_distances(segmentation, resolution, n_threads)
    distances, endpoints1, endpoints2, seg_ids = result

    if save_path is not None:
        np.savez(
            save_path, distances=distances, endpoints1=endpoints1, endpoints2=endpoints2, seg_ids=seg_ids
        )
    return distances, endpoints1, endpoints2, seg_ids
def _compute_seg_object_distances(segmentation, segmented_object, resolution, verbose):
    """Compute the boundary distance from every object in `segmentation` to `segmented_object`.

    Args:
        segmentation: The instance segmentation; the value zero is treated as background.
        segmented_object: The target; all non-zero voxels belong to it.
        resolution: Optional voxel size, passed as `sampling` to the distance transform.
        verbose: Whether to show a progress bar.

    Returns:
        The distance of each object to the target.
        The coordinates of the closest target voxel for each object.
        The coordinates of the closest voxel of each object.
        The segmentation ids (ascending); index k of all outputs refers to seg_ids[k].
        The value of `segmented_object` at the closest target voxel for each object.
    """
    # For every voxel: distance to and coordinate of the closest target voxel.
    distance_map, indices = distance_transform_edt(
        segmented_object == 0, return_indices=True, sampling=resolution
    )

    # Exclude the background explicitly instead of assuming it is the first unique value.
    seg_ids = np.unique(segmentation)
    seg_ids = seg_ids[seg_ids != 0]
    # Constant-time lookup of the output row for a label.
    id_to_index = {seg_id: idx for idx, seg_id in enumerate(seg_ids.tolist())}
    n = len(seg_ids)
    ndim = segmentation.ndim

    distances = np.zeros(n)
    endpoints1 = np.zeros((n, ndim), dtype="int")
    endpoints2 = np.zeros((n, ndim), dtype="int")
    object_ids = np.zeros(n, dtype=segmented_object.dtype)

    for prop in tqdm(regionprops(segmentation), disable=not verbose):
        label = prop.label
        # Bounding box as slices and its start offset, for any ndim.
        bb = tuple(slice(prop.bbox[d], prop.bbox[d + ndim]) for d in range(ndim))
        offset = np.array(prop.bbox[:ndim])

        mask = segmentation[bb] == label
        dist = distance_map[bb].copy()
        idx = indices[(slice(None),) + bb]
        # Only consider voxels of this object inside its bounding box.
        dist[~mask] = np.inf

        min_dist_coord = np.unravel_index(np.argmin(dist), mask.shape)
        object_coord = tuple(idx_[min_dist_coord] for idx_ in idx)
        object_id = segmented_object[object_coord]
        assert object_id != 0

        seg_idx = id_to_index[label]
        distances[seg_idx] = dist[min_dist_coord]
        endpoints1[seg_idx] = object_coord
        # Convert the bounding-box local coordinate into a global coordinate.
        endpoints2[seg_idx] = offset + np.array(min_dist_coord)
        object_ids[seg_idx] = object_id

    return distances, endpoints1, endpoints2, seg_ids, object_ids
def measure_segmentation_to_object_distances(
    segmentation: np.ndarray,
    segmented_object: np.ndarray,
    distance_type: str = "boundary",
    resolution: Optional[Tuple[int, int, int]] = None,
    save_path: Optional[os.PathLike] = None,
    verbose: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Measure the distance between every object in a segmentation and another object.

    Args:
        segmentation: The input segmentation.
        segmented_object: The segmented object.
        distance_type: Which distance to measure: 'boundary' for the distance between the
            boundaries / surfaces of the objects, 'centroid' for the distance between
            centroids. Only 'boundary' is implemented at the moment.
        resolution: The resolution / pixel size of the data.
        save_path: Optional path for storing the results as a zipped numpy file (`np.savez`).
            The file additionally contains the ids of the object found at each distance.
        verbose: Whether to show the progress of the distance computation.

    Returns:
        The segmentation to object distances.
        The 'left' endpoint coordinates of the distances.
        The 'right' endpoint coordinates of the distances.
        The segmentation ids corresponding to the distances.

    Raises:
        NotImplementedError: For any `distance_type` other than 'boundary'.
    """
    if distance_type != "boundary":
        raise NotImplementedError

    measurements = _compute_seg_object_distances(segmentation, segmented_object, resolution, verbose)
    distances, endpoints1, endpoints2, seg_ids, object_ids = measurements
    assert len(distances) == len(endpoints1) == len(endpoints2) == len(seg_ids) == len(object_ids)

    if save_path is not None:
        np.savez(
            save_path,
            distances=distances,
            endpoints1=endpoints1,
            endpoints2=endpoints2,
            seg_ids=seg_ids,
            object_ids=object_ids,
        )
    return distances, endpoints1, endpoints2, seg_ids
268def _extract_nearest_neighbors(pairwise_distances, seg_ids, n_neighbors, remove_duplicates=True):
269    distance_matrix = pairwise_distances.copy()
271    # Set the diagonal (distance to self) to infinity.
272    distance_matrix[np.diag_indices(len(distance_matrix))] = np.inf
273    # Mirror the distances.
274    # (We only compute upper triangle, but need to take all distances into account here)
275    tril_indices = np.tril_indices_from(distance_matrix)
276    distance_matrix[tril_indices] = distance_matrix.T[tril_indices]
278    neighbor_distances = np.sort(distance_matrix, axis=1)[:, :n_neighbors]
279    neighbor_indices = np.argsort(distance_matrix, axis=1)[:, :n_neighbors]
281    pairs = []
282    for i, (dists, inds) in enumerate(zip(neighbor_distances, neighbor_indices)):
283        seg_id = seg_ids[i]
284        ngb_ids = [seg_ids[j] for j, dist in zip(inds, dists) if np.isfinite(dist)]
285        pairs.extend([[min(seg_id, ngb_id), max(seg_id, ngb_id)] for ngb_id in ngb_ids])
287    pairs = np.array(pairs)
288    pairs = np.sort(pairs, axis=1)
289    if remove_duplicates:
290        pairs = np.unique(pairs, axis=0)
291    return pairs
def load_distances(
    measurement_path: os.PathLike
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load saved distances from a zipped numpy file.

    The file is expected to contain the keys 'distances', 'endpoints1', 'endpoints2'
    and 'seg_ids', as written by the measurement functions of this module.

    Args:
        measurement_path: The path where the distances were saved.

    Returns:
        The distances.
        The 'left' endpoint coordinates of the distances.
        The 'right' endpoint coordinates of the distances.
        The segmentation ids corresponding to the distances, as a list.
    """
    # Use the NpzFile as a context manager so that the underlying file handle is closed;
    # indexing it reads each array fully into memory.
    with np.load(measurement_path) as auto_dists:
        distances = auto_dists["distances"]
        endpoints1, endpoints2 = auto_dists["endpoints1"], auto_dists["endpoints2"]
        seg_ids = list(auto_dists["seg_ids"])
    return distances, endpoints1, endpoints2, seg_ids
def create_pairwise_distance_lines(
    distances: np.ndarray,
    endpoints1: np.ndarray,
    endpoints2: np.ndarray,
    seg_ids: List[List[int]],
    n_neighbors: Optional[int] = None,
    pairs: Optional[np.ndarray] = None,
    bb: Optional[Tuple[slice]] = None,
    scale: Optional[float] = None,
    remove_duplicates: bool = True
) -> Tuple[np.ndarray, Dict]:
    """Create a line representation of pair-wise object distances for display in napari.

    Args:
        distances: The pairwise distances, a (n, n) matrix indexed like `seg_ids`.
        endpoints1: One set of distance end points, shape (n, n, ndim).
        endpoints2: The other set of distance end points, shape (n, n, ndim).
        seg_ids: The segmentation ids corresponding to the rows / columns of the
            matrices (a list or a numpy array).
        n_neighbors: The number of nearest neighbors to take into consideration
            for creating the distance lines.
        pairs: Optional list of id pairs to use for creating the distance lines.
            If neither `pairs` nor `n_neighbors` is given, all pairs are used.
        bb: Bounding box (one slice per axis) for restricting the distance line creation.
            A line is kept if both end points satisfy start <= coordinate < stop on every
            axis; the coordinates of kept lines are given relative to the box start.
        scale: Scale factor for resizing the distance lines, either a single value or one
            value per axis; the coordinates are floor-divided by it.
            Use this if the corresponding segmentations were downscaled for visualization.
        remove_duplicates: Remove duplicate id pairs from the distance lines.

    Returns:
        The lines for plotting in napari, array of shape (n_lines, 2, ndim).
        Additional attributes for the line layer in napari.

    Raises:
        ValueError: If a pair contains an id that is not in `seg_ids`.
    """
    if pairs is None and n_neighbors is not None:
        pairs = _extract_nearest_neighbors(distances, seg_ids, n_neighbors, remove_duplicates=remove_duplicates)
    elif pairs is None:
        pairs = [[id1, id2] for id1 in seg_ids for id2 in seg_ids if id1 < id2]
    pairs = np.asarray(pairs).reshape(-1, 2)

    # Map ids to matrix indices once; works for lists and arrays alike.
    index_of = {seg_id: idx for idx, seg_id in enumerate(np.asarray(seg_ids).tolist())}
    try:
        rows = np.array([index_of[pair[0]] for pair in pairs], dtype=int)
        cols = np.array([index_of[pair[1]] for pair in pairs], dtype=int)
    except KeyError as err:
        raise ValueError(f"Segmentation id {err.args[0]} is not contained in seg_ids.") from None

    distances = distances[rows, cols]
    endpoints1 = endpoints1[rows, cols]
    endpoints2 = endpoints2[rows, cols]

    if bb is not None:
        starts = np.array([0 if b.start is None else b.start for b in bb], dtype=int)
        stops = np.array([np.inf if b.stop is None else b.stop for b in bb])
        # Slices are half-open, so the start coordinate itself is inside the box.
        in_bb = np.all(
            (endpoints1 >= starts) & (endpoints1 < stops) & (endpoints2 >= starts) & (endpoints2 < stops),
            axis=1,
        )
        pairs, distances = pairs[in_bb], distances[in_bb]
        endpoints1 = endpoints1[in_bb] - starts
        endpoints2 = endpoints2[in_bb] - starts

    lines = np.stack([endpoints1, endpoints2], axis=1)

    if scale is not None:
        # Out-of-place division: an in-place floor division of an integer array by a
        # float is not allowed by numpy's casting rules.
        lines = (lines // np.asarray(scale)).astype(lines.dtype, copy=False)

    properties = {
        "id_a": pairs[:, 0],
        "id_b": pairs[:, 1],
        "distance": np.round(distances, 2),
    }
    return lines, properties
def create_object_distance_lines(
    distances: np.ndarray,
    endpoints1: np.ndarray,
    endpoints2: np.ndarray,
    seg_ids: np.ndarray,
    max_distance: Optional[float] = None,
    filter_seg_ids: Optional[np.ndarray] = None,
    scale: Optional[float] = None,
) -> Tuple[np.ndarray, Dict]:
    """Create a line representation of object distances for display in napari.

    Args:
        distances: The measured distances, one per segmentation id.
        endpoints1: One set of distance end points, shape (n, ndim).
        endpoints2: The other set of distance end points, shape (n, ndim).
        seg_ids: The segmentation ids corresponding to each distance (list or array).
        max_distance: Maximal distance for drawing the distance line.
        filter_seg_ids: Segmentation ids to restrict the distance lines.
        scale: Scale factor for resizing the distance lines, either a single value or one
            value per axis; the coordinates are floor-divided by it.
            Use this if the corresponding segmentations were downscaled for visualization.

    Returns:
        The lines for plotting in napari, array of shape (n_lines, 2, ndim).
        Additional attributes for the line layer in napari.
    """
    # Arrays are needed for the boolean masking below (seg_ids may be a list).
    distances = np.asarray(distances)
    endpoints1, endpoints2 = np.asarray(endpoints1), np.asarray(endpoints2)
    seg_ids = np.asarray(seg_ids)

    if filter_seg_ids is not None:
        # Keep the ids aligned with the other arrays by masking them the same way.
        id_mask = np.isin(seg_ids, filter_seg_ids)
        distances, seg_ids = distances[id_mask], seg_ids[id_mask]
        endpoints1, endpoints2 = endpoints1[id_mask], endpoints2[id_mask]

    if max_distance is not None:
        distance_mask = distances <= max_distance
        distances, seg_ids = distances[distance_mask], seg_ids[distance_mask]
        endpoints1, endpoints2 = endpoints1[distance_mask], endpoints2[distance_mask]

    assert len(distances) == len(seg_ids) == len(endpoints1) == len(endpoints2)
    lines = np.stack([endpoints1, endpoints2], axis=1)

    if scale is not None and len(lines) > 0:
        # Out-of-place division: an in-place floor division of an integer array by a
        # float is not allowed by numpy's casting rules.
        lines = (lines // np.asarray(scale)).astype(lines.dtype, copy=False)

    properties = {"id": seg_ids, "distance": np.round(distances, 2)}
    return lines, properties
def keep_direct_distances(
    segmentation: np.ndarray,
    distances: np.ndarray,
    endpoints1: np.ndarray,
    endpoints2: np.ndarray,
    seg_ids: np.ndarray,
    line_dilation: int = 0,
    scale: Optional[Tuple[int, int, int]] = None,
) -> List[List[int]]:
    """Filter out all distances that are not direct, i.e. that are occluded by another segmented object.

    Args:
        segmentation: The segmentation from which the distances are derived.
        distances: The pairwise distances, as returned by `measure_pairwise_object_distances`.
        endpoints1: One set of distance end points.
        endpoints2: The other set of distance end points.
        seg_ids: The segmentation ids corresponding to the rows / columns of the distances.
        line_dilation: Dilation factor of the distance lines for determining occlusions.
        scale: Scaling factor of the segmentation compared to the distance measurements.

    Returns:
        The list of id pairs that are kept.
    """
    # The distances are pairwise, so the lines are built with the pairwise helper, which
    # provides the 'id_a' / 'id_b' properties used below. The ids are passed as a list,
    # which that helper indexes by id.
    distance_lines, properties = create_pairwise_distance_lines(
        distances, endpoints1, endpoints2, list(seg_ids), scale=scale
    )

    ids_a, ids_b = properties["id_a"], properties["id_b"]
    filtered_ids_a, filtered_ids_b = [], []

    for i, line in tqdm(enumerate(distance_lines), total=len(distance_lines)):
        id_a, id_b = ids_a[i], ids_b[i]

        start, stop = line
        line = line_nd(start, stop, endpoint=True)

        if line_dilation > 0:
            # TODO make this more efficient, ideally by dilating the mask coordinates
            # instead of dilating the actual mask.
            # We turn the line into a binary mask and dilate it to have some tolerance.
            line_vol = np.zeros_like(segmentation)
            line_vol[line] = 1
            line_vol = binary_dilation(line_vol, iterations=line_dilation)
        else:
            line_vol = line

        # Check whether the line crosses any other segment: collect the ids overlapping
        # with the line. It counts as a direct distance if no other object overlaps.
        line_seg_ids = np.unique(segmentation[line_vol])
        line_seg_ids = np.setdiff1d(line_seg_ids, [0, id_a, id_b])

        if len(line_seg_ids) == 0:  # No other object is overlapping, we keep the line.
            filtered_ids_a.append(id_a)
            filtered_ids_b.append(id_b)

    print("Keeping", len(filtered_ids_a), "/", len(ids_a), "distance pairs")
    filtered_pairs = [[ida, idb] for ida, idb in zip(filtered_ids_a, filtered_ids_b)]
    return filtered_pairs
def filter_blocked_segmentation_to_object_distances(
    segmentation: np.ndarray,
    distances: np.ndarray,
    endpoints1: np.ndarray,
    endpoints2: np.ndarray,
    seg_ids: np.ndarray,
    line_dilation: int = 0,
    scale: Optional[Tuple[int, int, int]] = None,
    filter_seg_ids: Optional[List[int]] = None,
    verbose: bool = False,
) -> List[int]:
    """Filter out all distances that are not direct, i.e. that are occluded by another segmented object.

    Args:
        segmentation: The segmentation from which the distances are derived.
        distances: The measured distances.
        endpoints1: One set of distance end points.
        endpoints2: The other set of distance end points.
        seg_ids: The segmentation ids corresponding to each distance.
        line_dilation: Dilation factor of the distance lines for determining occlusions.
        scale: Scaling factor of the segmentation compared to the distance measurements.
        filter_seg_ids: Segmentation ids to restrict the check to. If None, all ids are checked.
        verbose: Whether to show a progress bar.

    Returns:
        The list of segmentation ids whose distance line is not blocked by another object.
    """
    distance_lines, properties = create_object_distance_lines(
        distances, endpoints1, endpoints2, seg_ids, scale=scale
    )
    all_seg_ids = properties["id"]

    # Ids to restrict the check to; a set gives constant-time membership tests.
    keep_ids = None if filter_seg_ids is None else set(np.asarray(filter_seg_ids).tolist())

    filtered_ids = []
    for seg_id, line in tqdm(zip(all_seg_ids, distance_lines), total=len(distance_lines), disable=not verbose):
        if keep_ids is not None and seg_id not in keep_ids:
            continue

        start, stop = line
        line = line_nd(start, stop, endpoint=True)

        if line_dilation > 0:
            # TODO make this more efficient, ideally by dilating the mask coordinates
            # instead of dilating the actual mask.
            # We turn the line into a binary mask and dilate it to have some tolerance.
            line_vol = np.zeros_like(segmentation)
            line_vol[line] = 1
            line_vol = binary_dilation(line_vol, iterations=line_dilation)
        else:
            line_vol = line

        # Check whether the line crosses any other segment: collect the ids overlapping
        # with the line. It counts as a direct distance if no other object overlaps.
        line_seg_ids = np.unique(segmentation[line_vol])
        line_seg_ids = np.setdiff1d(line_seg_ids, [0, seg_id])

        if len(line_seg_ids) == 0:  # No other object is overlapping, we keep the line.
            filtered_ids.append(seg_id)

    return filtered_ids
