micro_sam.training.util
import os
from math import ceil, floor
from typing import Dict, List, Optional, Union

import numpy as np

import torch

from segment_anything.utils.transforms import ResizeLongestSide

from ..prompt_generators import PointAndBoxPromptGenerator
from ..util import (
    get_centers_and_bounding_boxes, get_sam_model, get_device,
    segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .. import models as custom_models
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.transform.raw import normalize_percentile, normalize


def identity(x):
    """Identity transformation.

    This is a helper function to skip data normalization when finetuning SAM.
    Data normalization is performed within the model and should thus be skipped as
    a preprocessing step in training.
    """
    return x


def require_8bit(x):
    """Transformation to ensure an 8-bit input data range (0-255)."""
    if x.max() < 1:
        x = x * 255
    return x


def get_trainable_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs
) -> TrainableSAM:
    """Get the trainable SAM model.

    Args:
        model_type: The Segment Anything model that should be finetuned.
            The weights of this model will be used for initialization, unless a
            custom weight file is passed via `checkpoint_path`.
        device: The device to use for training.
        checkpoint_path: Path to a custom checkpoint from which to load the model weights.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder
            and mask_decoder. By default nothing is frozen and the full model is updated.
        return_state: Whether to return the full checkpoint state.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.

    Returns:
        The trainable Segment Anything model.
    """
    # Set the device here so that the correct one is passed to TrainableSAM below.
    device = get_device(device)
    _, sam, state = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        return_state=True,
        flexible_load_checkpoint=flexible_load_checkpoint,
        **model_kwargs
    )

    # NOTE: This is done here rather than in "get_sam_model" to use PEFT's layer-specific initialization on top.
    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allowing PEFT methods.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        if model_type[:5] == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # Freeze components of the model if freeze was passed.
    # Ideally we would want to add components in such a way that:
    # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components
    #   to the network (e.g. encoder blocks to "image_encoder").
    if freeze is not None:
        for name, param in sam.named_parameters():
            if not isinstance(freeze, list):
                # We "freeze" only one specific component when passed a "particular" part.
                freeze = [freeze]

            # We want to "freeze" all the listed components of the model if passed a list of parts.
            for l_item in freeze:
                # In case PEFT is switched on, we cannot freeze the image encoder.
                if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
                    raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

                if name.startswith(f"{l_item}"):
                    param.requires_grad = False

    # Convert to trainable sam.
    trainable_sam = TrainableSAM(sam)

    if return_state:
        return trainable_sam, state
    return trainable_sam


class ConvertToSamInputs:
    """Convert outputs of the data loader to the expected batched inputs of the Segment Anything model.

    Args:
        transform: The transformation to resize the prompts. Should be the same transform used in the
            model to resize the inputs. If `None` the prompts will not be resized.
        dilation_strength: The dilation factor.
            It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts
            due to imprecise groundtruth masks.
        box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
    """
    def __init__(
        self,
        transform: Optional[ResizeLongestSide],
        dilation_strength: int = 10,
        box_distortion_factor: Optional[float] = None,
    ) -> None:
        self.dilation_strength = dilation_strength
        # NOTE: Keep `None` here, so that the checks in `__call__` skip the prompt resizing,
        # as documented in the class docstring.
        self.transform = transform
        self.box_distortion_factor = box_distortion_factor

    def _distort_boxes(self, bbox_coordinates, shape):
        distorted_boxes = []
        for bbox in bbox_coordinates:
            # The bounding box is parametrized by y0, x0, y1, x1.
            y0, x0, y1, x1 = bbox
            ly, lx = y1 - y0, x1 - x0
            y0 = int(round(max(0, y0 - np.random.uniform(0, self.box_distortion_factor) * ly)))
            y1 = int(round(min(shape[0], y1 + np.random.uniform(0, self.box_distortion_factor) * ly)))
            x0 = int(round(max(0, x0 - np.random.uniform(0, self.box_distortion_factor) * lx)))
            x1 = int(round(min(shape[1], x1 + np.random.uniform(0, self.box_distortion_factor) * lx)))
            distorted_boxes.append([y0, x0, y1, x1])
        return distorted_boxes

    def _get_prompt_lists(self, gt, n_samples, prompt_generator):
        """Returns a list of "expected" prompts subjected to the random input attributes for prompting."""
        _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p")

        # Get the segment ids.
        cell_ids = np.unique(gt)[1:]
        if n_samples is None:  # n_samples is set to None, so we use all ids.
            sampled_cell_ids = cell_ids
        else:  # n_samples is set, so we subsample the cell ids.
            sampled_cell_ids = np.random.choice(cell_ids, size=min(n_samples, len(cell_ids)), replace=False)
            sampled_cell_ids = np.sort(sampled_cell_ids)

        # Only keep the bounding boxes for the sampled cell ids.
        bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids]
        if self.box_distortion_factor is not None:
            bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])

        # Convert the gt to the one-hot-encoded masks for the sampled cell ids.
        object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)

        # Derive and return the prompts.
        point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
        return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids

    def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
        """Convert the outputs of the dataloader and prompt settings to the batch format expected by SAM."""
        # If we get no point prompts at all, we don't use point prompting.
        if n_pos == 0 and n_neg == 0:
            get_points = False
        else:
            get_points = True

        # Keeping the solution open by checking for deterministic/dynamic choice of point prompts.
        prompt_generator = PointAndBoxPromptGenerator(n_positive_points=n_pos,
                                                      n_negative_points=n_neg,
                                                      dilation_strength=self.dilation_strength,
                                                      get_box_prompts=get_boxes,
                                                      get_point_prompts=get_points)

        batched_inputs = []
        batched_sampled_cell_ids_list = []

        for image, gt in zip(x, y):
            gt = gt.squeeze().numpy().astype(np.int64)
            box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
                gt, n_samples, prompt_generator,
            )

            # Check to be sure about the expected number of elements in the different settings.
            if get_boxes:
                assert len(sampled_cell_ids) == len(box_prompts), f"{len(sampled_cell_ids)}, {len(box_prompts)}"
            if get_points:
                assert len(sampled_cell_ids) == len(point_prompts) == len(point_label_prompts), \
                    f"{len(sampled_cell_ids)}, {len(point_prompts)}, {len(point_label_prompts)}"

            batched_sampled_cell_ids_list.append(sampled_cell_ids)

            batched_input = {"image": image, "original_size": image.shape[1:]}
            if get_boxes:
                batched_input["boxes"] = self.transform.apply_boxes_torch(
                    box_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else box_prompts
            if get_points:
                batched_input["point_coords"] = self.transform.apply_coords_torch(
                    point_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else point_prompts
                batched_input["point_labels"] = point_label_prompts

            batched_inputs.append(batched_input)

        return batched_inputs, batched_sampled_cell_ids_list


class ConvertToSemanticSamInputs:
    """Convert outputs of the data loader to the expected batched inputs of the Segment Anything model
    for semantic segmentation.
    """
    def __call__(self, x, y):
        """Convert the outputs of the dataloader to the batched format of inputs expected by SAM."""
        batched_inputs = []
        for image, gt in zip(x, y):
            batched_input = {"image": image, "original_size": image.shape[-2:]}
            batched_inputs.append(batched_input)

        return batched_inputs


#
# Raw and Label Transformations for the Generalist and Specialist finetuning
#


def normalize_to_8bit(raw):
    raw = normalize(raw) * 255
    return raw


class ResizeRawTrafo:
    def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
        self.desired_shape = desired_shape
        self.padding = padding
        self.do_rescaling = do_rescaling

    def __call__(self, raw):
        if self.do_rescaling:
            raw = normalize_percentile(raw, axis=(1, 2))
            raw = np.mean(raw, axis=0)
            raw = normalize(raw)
            raw = raw * 255

        tmp_ddim = (self.desired_shape[0] - raw.shape[0], self.desired_shape[1] - raw.shape[1])
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
        raw = np.pad(
            raw,
            pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
            mode=self.padding
        )
        assert raw.shape == self.desired_shape
        return raw


class ResizeLabelTrafo:
    def __init__(self, desired_shape, padding="constant", min_size=0):
        self.desired_shape = desired_shape
        self.padding = padding
        self.min_size = min_size

    def __call__(self, labels):
        distance_trafo = PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=self.min_size
        )
        labels = distance_trafo(labels)

        # Choosing H and W from labels (4, H, W), from the distance trafo outputs above.
        # NOTE: The second entry must use `desired_shape[1]` so that non-square shapes are padded correctly.
        tmp_ddim = (self.desired_shape[0] - labels.shape[1], self.desired_shape[1] - labels.shape[2])
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
        labels = np.pad(
            labels,
            pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
            mode=self.padding
        )
        assert labels.shape[1:] == self.desired_shape, labels.shape
        return labels
def identity(x):
Identity transformation.
This is a helper function to skip data normalization when finetuning SAM. Data normalization is performed within the model and should thus be skipped as a preprocessing step in training.
def require_8bit(x):
Transformation to ensure an 8-bit input data range (0-255).
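For example (a minimal sketch with synthetic data; data whose maximum already exceeds 1 is passed through unchanged):

import numpy as np

from micro_sam.training.util import require_8bit

# A float image in [0, 1): its maximum is below 1, so it is rescaled to the 8-bit range.
x = np.random.rand(256, 256).astype("float32")
print(require_8bit(x).max())  # close to 255

# Data that already exceeds the unit range is returned unchanged.
y = np.random.randint(0, 256, size=(256, 256)).astype("float32")
print(require_8bit(y).max())  # e.g. 255.0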
def get_trainable_sam_model(
    model_type: str = 'vit_l',
    device: Union[str, torch.device, NoneType] = None,
    checkpoint_path: Union[str, os.PathLike, NoneType] = None,
    freeze: Optional[List[str]] = None,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs
) -> micro_sam.training.trainable_sam.TrainableSAM:
Get the trainable SAM model.
Arguments:
- model_type: The Segment Anything model that should be finetuned. The weights of this model will be used for initialization, unless a custom weight file is passed via `checkpoint_path`.
- device: The device to use for training.
- checkpoint_path: Path to a custom checkpoint from which to load the model weights.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- return_state: Whether to return the full checkpoint state.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
- model_kwargs: Additional keyword arguments for `util.get_sam_model`.
Returns:
The trainable Segment Anything model.
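For example, to update only the image encoder during finetuning, the prompt encoder and mask decoder can be frozen (a usage sketch; 'vit_b' stands in for any supported model type and its weights are downloaded on first use):

from micro_sam.training.util import get_trainable_sam_model

# Load the 'vit_b' weights and freeze everything except the image encoder.
model = get_trainable_sam_model(model_type="vit_b", freeze=["prompt_encoder", "mask_decoder"])

# Only the image encoder parameters remain trainable.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("trainable parameters:", n_trainable)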
class ConvertToSamInputs:
Convert outputs of the data loader to the expected batched inputs of the Segment Anything model.
Arguments:
- transform: The transformation to resize the prompts. Should be the same transform used in the model to resize the inputs. If `None` the prompts will not be resized.
- dilation_strength: The dilation factor. It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts due to imprecise groundtruth masks.
- box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
ConvertToSamInputs(
    transform: Optional[segment_anything.utils.transforms.ResizeLongestSide],
    dilation_strength: int = 10,
    box_distortion_factor: Optional[float] = None,
)
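A minimal usage sketch with a synthetic batch; it assumes the standard SAM input size of 1024 for ResizeLongestSide, i.e. the same transform the model applies to its image inputs:

import torch

from segment_anything.utils.transforms import ResizeLongestSide
from micro_sam.training.util import ConvertToSamInputs

convert_inputs = ConvertToSamInputs(transform=ResizeLongestSide(1024))

# A batch with one image and a groundtruth with two objects.
x = torch.zeros(1, 1, 256, 256)
y = torch.zeros(1, 1, 256, 256, dtype=torch.int64)
y[0, 0, 16:64, 16:64] = 1
y[0, 0, 128:192, 128:192] = 2

# Sample one positive point per object and also derive box prompts.
batched_inputs, sampled_ids = convert_inputs(x, y, n_pos=1, n_neg=0, get_boxes=True)
print(sorted(batched_inputs[0].keys()))
# ['boxes', 'image', 'original_size', 'point_coords', 'point_labels']
print(sampled_ids)  # the object ids for which prompts were sampled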
class ConvertToSemanticSamInputs:
Convert outputs of the data loader to the expected batched inputs of the Segment Anything model for semantic segmentation.
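No prompts are derived here; each image is only wrapped into the batched input format, as in this minimal sketch:

import torch

from micro_sam.training.util import ConvertToSemanticSamInputs

convert_inputs = ConvertToSemanticSamInputs()
x = torch.zeros(2, 1, 512, 512)  # a batch of two images
y = torch.zeros(2, 1, 512, 512)  # the semantic labels are not needed for the input dicts
batched_inputs = convert_inputs(x, y)
print(batched_inputs[0]["original_size"])  # torch.Size([512, 512])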
def normalize_to_8bit(raw):
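Rescales the input to [0, 1] with torch_em's normalize and multiplies by 255, mapping arbitrary intensity ranges to 8-bit. For example:

import numpy as np

from micro_sam.training.util import normalize_to_8bit

raw = np.random.rand(128, 128).astype("float32") * 4000.0  # an arbitrary intensity range
raw_8bit = normalize_to_8bit(raw)
print(raw_8bit.min(), raw_8bit.max())  # approximately 0.0 and 255.0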
class ResizeRawTrafo:
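Pads raw data to desired_shape; with do_rescaling=True, channel-first multi-channel input is additionally percentile-normalized, averaged over the channels and rescaled to the 8-bit range. Since only padding is applied, the input must not be larger than desired_shape. A minimal sketch:

import numpy as np

from micro_sam.training.util import ResizeRawTrafo

trafo = ResizeRawTrafo(desired_shape=(512, 512))
raw = np.random.randint(0, 256, size=(480, 500)).astype("float32")
print(trafo(raw).shape)  # (512, 512), padded symmetrically on both axes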
class ResizeLabelTrafo:
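Computes per-object distance targets with torch_em's PerObjectDistanceTransform, yielding a (4, H, W) stack of foreground, distance and instance channels, and then pads them to desired_shape. A minimal sketch:

import numpy as np

from micro_sam.training.util import ResizeLabelTrafo

trafo = ResizeLabelTrafo(desired_shape=(512, 512), min_size=10)
labels = np.zeros((480, 480), dtype="int64")
labels[100:200, 100:200] = 1  # a single object; objects below min_size would be filtered out
print(trafo(labels).shape)  # (4, 512, 512)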