synapse_net.training.transform
"""Label transforms used for training."""
import numpy as np
from torch_em.transform.label import labels_to_binary
from scipy.ndimage import distance_transform_edt


class AZDistanceLabelTransform:
    """Label transform producing a foreground mask and a normalized distance map.

    Calling the transform on a label array returns two channels stacked
    along a new leading axis:

    * channel 0: the binary foreground mask obtained from ``labels_to_binary``;
    * channel 1: the Euclidean distance of every pixel to the nearest
      foreground pixel (unit spacing), clipped to ``max_distance`` and divided
      by it, so values lie in ``[0, 1]`` and are 0 on the foreground itself.

    Both channels are ``float32``.

    Args:
        max_distance: Distance at which the distance channel saturates to 1.
            Must be positive.

    Raises:
        ValueError: If ``max_distance`` is not a positive number.
    """

    def __init__(self, max_distance: float = 50.0):
        # A non-positive (or NaN) value would make the normalization in
        # __call__ divide by zero / produce NaN targets; `not x > 0` also
        # rejects NaN, which `x <= 0` would let through.
        if not max_distance > 0:
            raise ValueError(f"max_distance must be positive, got {max_distance}.")
        self.max_distance = max_distance

    def __call__(self, input_):
        """Convert a label array into the stacked (mask, distance) target.

        Args:
            input_: Label array. Which values count as foreground is decided by
                ``labels_to_binary`` (presumably every non-background label;
                TODO confirm against torch_em).

        Returns:
            A ``float32`` array of shape ``(2,) + binary_target.shape`` holding
            the binary mask and the normalized distance map.
        """
        binary_target = labels_to_binary(input_).astype("float32")
        if not binary_target.any():
            # No foreground at all: treat every pixel as saturated (normalized
            # distance 1) rather than measuring distance to a non-existent object.
            distances = np.ones_like(binary_target, dtype="float32")
        else:
            # EDT of the background mask: for each background pixel, the distance
            # to the nearest foreground pixel; foreground pixels get 0.
            distances = distance_transform_edt(binary_target == 0)
            distances = np.clip(distances, 0.0, self.max_distance) / self.max_distance
            # distance_transform_edt returns float64; without this cast np.stack
            # would promote the whole target to float64 in this branch only.
            distances = distances.astype("float32")
        return np.stack([binary_target, distances])
class
AZDistanceLabelTransform:
class AZDistanceLabelTransform:
    """Label transform producing a foreground mask and a normalized distance map.

    Calling the transform on a label array returns two channels stacked
    along a new leading axis:

    * channel 0: the binary foreground mask obtained from ``labels_to_binary``;
    * channel 1: the Euclidean distance of every pixel to the nearest
      foreground pixel (unit spacing), clipped to ``max_distance`` and divided
      by it, so values lie in ``[0, 1]`` and are 0 on the foreground itself.

    Both channels are ``float32``.

    Args:
        max_distance: Distance at which the distance channel saturates to 1.
            Must be positive.

    Raises:
        ValueError: If ``max_distance`` is not a positive number.
    """

    def __init__(self, max_distance: float = 50.0):
        # A non-positive (or NaN) value would make the normalization in
        # __call__ divide by zero / produce NaN targets; `not x > 0` also
        # rejects NaN, which `x <= 0` would let through.
        if not max_distance > 0:
            raise ValueError(f"max_distance must be positive, got {max_distance}.")
        self.max_distance = max_distance

    def __call__(self, input_):
        """Convert a label array into the stacked (mask, distance) target.

        Args:
            input_: Label array. Which values count as foreground is decided by
                ``labels_to_binary`` (presumably every non-background label;
                TODO confirm against torch_em).

        Returns:
            A ``float32`` array of shape ``(2,) + binary_target.shape`` holding
            the binary mask and the normalized distance map.
        """
        binary_target = labels_to_binary(input_).astype("float32")
        if not binary_target.any():
            # No foreground at all: treat every pixel as saturated (normalized
            # distance 1) rather than measuring distance to a non-existent object.
            distances = np.ones_like(binary_target, dtype="float32")
        else:
            # EDT of the background mask: for each background pixel, the distance
            # to the nearest foreground pixel; foreground pixels get 0.
            distances = distance_transform_edt(binary_target == 0)
            distances = np.clip(distances, 0.0, self.max_distance) / self.max_distance
            # distance_transform_edt returns float64; without this cast np.stack
            # would promote the whole target to float64 in this branch only.
            distances = distances.astype("float32")
        return np.stack([binary_target, distances])