synapse_net.training.transform

 1import numpy as np
 2from torch_em.transform.label import labels_to_binary
 3from scipy.ndimage import distance_transform_edt
 4
 5
 6class AZDistanceLabelTransform:
 7    def __init__(self, max_distance: float = 50.0):
 8        self.max_distance = max_distance
 9
10    def __call__(self, input_):
11        binary_target = labels_to_binary(input_).astype("float32")
12        if binary_target.sum() == 0:
13            distances = np.ones_like(binary_target, dtype="float32")
14        else:
15            distances = distance_transform_edt(binary_target == 0)
16            distances = np.clip(distances, 0.0, self.max_distance)
17            distances /= self.max_distance
18        return np.stack([binary_target, distances])
class AZDistanceLabelTransform:
 7class AZDistanceLabelTransform:
 8    def __init__(self, max_distance: float = 50.0):
 9        self.max_distance = max_distance
10
11    def __call__(self, input_):
12        binary_target = labels_to_binary(input_).astype("float32")
13        if binary_target.sum() == 0:
14            distances = np.ones_like(binary_target, dtype="float32")
15        else:
16            distances = distance_transform_edt(binary_target == 0)
17            distances = np.clip(distances, 0.0, self.max_distance)
18            distances /= self.max_distance
19        return np.stack([binary_target, distances])
AZDistanceLabelTransform(max_distance: float = 50.0)
8    def __init__(self, max_distance: float = 50.0):
9        self.max_distance = max_distance
max_distance