micro_sam.training.util

import os
from math import ceil, floor
from typing import Dict, List, Optional, Union

import numpy as np

import torch

from segment_anything.utils.transforms import ResizeLongestSide

from ..prompt_generators import PointAndBoxPromptGenerator
from ..util import (
    get_centers_and_bounding_boxes, get_sam_model, get_device,
    segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .. import models as custom_models
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.transform.raw import normalize_percentile, normalize

def identity(x):
    """Identity transformation.

    This is a helper function to skip data normalization when finetuning SAM.
    Data normalization is performed within the model and should thus be skipped as
    a preprocessing step in training.
    """
    return x


def require_8bit(x):
    """Transformation to ensure an 8-bit input data range (0-255).
    """
    if x.max() < 1:
        x = x * 255
    return x

def get_trainable_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs
) -> TrainableSAM:
    """Get the trainable SAM model.

    Args:
        model_type: The Segment Anything model that should be finetuned.
            The weights of this model will be used for initialization, unless a
            custom weight file is passed via `checkpoint_path`.
        device: The device to use for training.
        checkpoint_path: Path to a custom checkpoint from which to load the model weights.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        return_state: Whether to return the full checkpoint state.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        flexible_load_checkpoint: Whether to adjust mismatching parameters while loading pretrained checkpoints.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.

    Returns:
        The trainable Segment Anything model.
    """
    # Set the device here so that the correct one is passed to TrainableSAM below.
    device = get_device(device)
    _, sam, state = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        return_state=True,
        flexible_load_checkpoint=flexible_load_checkpoint,
        **model_kwargs
    )

    # NOTE: This is done here rather than in "get_sam_model" so that PEFT's layer-specific
    # initialization is applied on top of the loaded weights.
    # If PEFT keyword arguments are given, wrap Segment Anything with the PEFT method,
    # which freezes the backbone and adds the parameter-efficient layers.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        if model_type[:5] == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # Freeze the requested components of the model, if freeze was passed.
    # Ideally we would want to add components in such a way that:
    # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network
    #   (e.g. encoder blocks to "image_encoder")
    if freeze is not None:
        for name, param in sam.named_parameters():
            if not isinstance(freeze, list):
                # We freeze only one specific component when passed a single part.
                freeze = [freeze]

            # We freeze all the listed components of the model when passed a list of parts.
            for l_item in freeze:
                # In case PEFT is switched on, we cannot freeze the image encoder.
                if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
                    raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

                if name.startswith(f"{l_item}"):
                    param.requires_grad = False

    # Convert to the trainable SAM wrapper.
    trainable_sam = TrainableSAM(sam)

    if return_state:
        return trainable_sam, state
    return trainable_sam


class ConvertToSamInputs:
    """Convert outputs of the data loader to the batched inputs expected by the Segment Anything model.

    Args:
        transform: The transformation to resize the prompts. Should be the same transform used in the
            model to resize the inputs. If `None` the prompts will not be resized.
        dilation_strength: The dilation factor.
            It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts
            due to imprecise groundtruth masks.
        box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
    """
    def __init__(
        self,
        transform: Optional[ResizeLongestSide],
        dilation_strength: int = 10,
        box_distortion_factor: Optional[float] = None,
    ) -> None:
        self.dilation_strength = dilation_strength
        self.transform = identity if transform is None else transform
        self.box_distortion_factor = box_distortion_factor

    def _distort_boxes(self, bbox_coordinates, shape):
        distorted_boxes = []
        for bbox in bbox_coordinates:
            # The bounding box is parametrized by y0, x0, y1, x1.
            y0, x0, y1, x1 = bbox
            ly, lx = y1 - y0, x1 - x0
            y0 = int(round(max(0, y0 - np.random.uniform(0, self.box_distortion_factor) * ly)))
            y1 = int(round(min(shape[0], y1 + np.random.uniform(0, self.box_distortion_factor) * ly)))
            x0 = int(round(max(0, x0 - np.random.uniform(0, self.box_distortion_factor) * lx)))
            x1 = int(round(min(shape[1], x1 + np.random.uniform(0, self.box_distortion_factor) * lx)))
            distorted_boxes.append([y0, x0, y1, x1])
        return distorted_boxes

    def _get_prompt_lists(self, gt, n_samples, prompt_generator):
        """Return the prompts for the sampled objects, subject to the random prompt settings."""

        _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p")

        # Get the segment ids.
        cell_ids = np.unique(gt)[1:]
        if n_samples is None:  # n_samples is set to None, so we use all ids
            sampled_cell_ids = cell_ids

        else:  # n_samples is set, so we subsample the cell ids
            sampled_cell_ids = np.random.choice(cell_ids, size=min(n_samples, len(cell_ids)), replace=False)
            sampled_cell_ids = np.sort(sampled_cell_ids)

        # Only keep the bounding boxes for the sampled cell ids.
        bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids]
        if self.box_distortion_factor is not None:
            bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])

        # Convert the gt to one-hot-encoded masks for the sampled cell ids.
        object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)

        # Derive and return the prompts.
        point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
        return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids

    def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
        """Convert the outputs of the data loader and the prompt settings to the batch format expected by SAM.
        """
        # Use point prompting only if positive or negative point prompts are requested.
        if n_pos == 0 and n_neg == 0:
            get_points = False
        else:
            get_points = True

        # The prompt generator handles the sampling of both point and box prompts.
        prompt_generator = PointAndBoxPromptGenerator(n_positive_points=n_pos,
                                                      n_negative_points=n_neg,
                                                      dilation_strength=self.dilation_strength,
                                                      get_box_prompts=get_boxes,
                                                      get_point_prompts=get_points)

        batched_inputs = []
        batched_sampled_cell_ids_list = []

        for image, gt in zip(x, y):
            gt = gt.squeeze().numpy().astype(np.int64)
            box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
                gt, n_samples, prompt_generator,
            )

            # Make sure the number of prompts matches the number of sampled objects in each setting.
            if get_boxes:
                assert len(sampled_cell_ids) == len(box_prompts), f"{len(sampled_cell_ids)}, {len(box_prompts)}"

            if get_points:
                assert len(sampled_cell_ids) == len(point_prompts) == len(point_label_prompts), \
                    f"{len(sampled_cell_ids)}, {len(point_prompts)}, {len(point_label_prompts)}"

            batched_sampled_cell_ids_list.append(sampled_cell_ids)

            batched_input = {"image": image, "original_size": image.shape[1:]}
            if get_boxes:
                batched_input["boxes"] = self.transform.apply_boxes_torch(
                    box_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else box_prompts
            if get_points:
                batched_input["point_coords"] = self.transform.apply_coords_torch(
                    point_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else point_prompts
                batched_input["point_labels"] = point_label_prompts

            batched_inputs.append(batched_input)

        return batched_inputs, batched_sampled_cell_ids_list


class ConvertToSemanticSamInputs:
    """Convert outputs of the data loader to the batched inputs expected by the Segment Anything model
    for semantic segmentation.
    """
    def __call__(self, x, y):
        """Convert the outputs of the data loader to the batched input format expected by SAM.
        """
        batched_inputs = []
        for image, gt in zip(x, y):
            batched_input = {"image": image, "original_size": image.shape[-2:]}
            batched_inputs.append(batched_input)

        return batched_inputs


#
# Raw and Label Transformations for the Generalist and Specialist finetuning
#


def normalize_to_8bit(raw):
    raw = normalize(raw) * 255
    return raw


class ResizeRawTrafo:
    def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
        self.desired_shape = desired_shape
        self.padding = padding
        self.do_rescaling = do_rescaling

    def __call__(self, raw):
        if self.do_rescaling:
            raw = normalize_percentile(raw, axis=(1, 2))
            raw = np.mean(raw, axis=0)
            raw = normalize(raw)
            raw = raw * 255

        # Pad the raw data to the desired shape.
        tmp_ddim = (self.desired_shape[0] - raw.shape[0], self.desired_shape[1] - raw.shape[1])
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
        raw = np.pad(
            raw,
            pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
            mode=self.padding
        )
        assert raw.shape == self.desired_shape
        return raw


class ResizeLabelTrafo:
    def __init__(self, desired_shape, padding="constant", min_size=0):
        self.desired_shape = desired_shape
        self.padding = padding
        self.min_size = min_size

    def __call__(self, labels):
        distance_trafo = PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=self.min_size
        )
        labels = distance_trafo(labels)

        # The distance transform above returns labels of shape (4, H, W); pad H and W to the desired shape.
        tmp_ddim = (self.desired_shape[0] - labels.shape[1], self.desired_shape[1] - labels.shape[2])
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
        labels = np.pad(
            labels,
            pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
            mode=self.padding
        )
        assert labels.shape[1:] == self.desired_shape, labels.shape
        return labels
def identity(x):

Identity transformation.

This is a helper function to skip data normalization when finetuning SAM. Data normalization is performed within the model and should thus be skipped as a preprocessing step in training.

def require_8bit(x):

Transformation to ensure an 8-bit input data range (0-255).
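
For illustration, a small sketch with dummy tensors (the shapes are arbitrary):

import torch
from micro_sam.training.util import require_8bit

x = torch.rand(1, 512, 512)                  # values in [0, 1) are rescaled to the 8-bit range
print(require_8bit(x).max())                 # close to 255
x8 = torch.randint(0, 256, (1, 512, 512)).float()
print(require_8bit(x8).max())                # unchanged, already in the 8-bit range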

def get_trainable_sam_model( model_type: str = 'vit_l', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, freeze: Optional[List[str]] = None, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs) -> micro_sam.training.trainable_sam.TrainableSAM:

Get the trainable SAM model.

Arguments:
  • model_type: The Segment Anything model that should be finetuned. The weights of this model will be used for initialization, unless a custom weight file is passed via checkpoint_path.
  • device: The device to use for training.
  • checkpoint_path: Path to a custom checkpoint from which to load the model weights.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
  • return_state: Whether to return the full checkpoint state.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • flexible_load_checkpoint: Whether to adjust mismatching parameters while loading pretrained checkpoints.
  • model_kwargs: Additional keyword arguments for util.get_sam_model.
Returns:

The trainable Segment Anything model.
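
A minimal usage sketch; the model type, device, and frozen part below are illustrative choices, not requirements of this function:

from micro_sam.training.util import get_trainable_sam_model

# Initialize a trainable model from the pretrained "vit_b" weights and keep the
# prompt encoder frozen; valid parts to freeze are "image_encoder",
# "prompt_encoder" and "mask_decoder".
trainable_sam = get_trainable_sam_model(
    model_type="vit_b",
    device="cuda",  # or "cpu", or a torch.device
    freeze=["prompt_encoder"],
)
# The returned TrainableSAM wrapper can be used like a regular torch.nn.Module in a training loop.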

class ConvertToSamInputs:

Convert outputs of the data loader to the batched inputs expected by the Segment Anything model.

Arguments:
  • transform: The transformation to resize the prompts. Should be the same transform used in the model to resize the inputs. If None the prompts will not be resized.
  • dilation_strength: The dilation factor. It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts due to imprecise groundtruth masks.
  • box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
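
A hedged sketch of applying this conversion to one batch; the image size, the dummy labels, and the prompt settings below are illustrative assumptions:

import torch
from segment_anything.utils.transforms import ResizeLongestSide
from micro_sam.training.util import ConvertToSamInputs

# Dummy batch standing in for a data loader output: images of shape (B, C, H, W)
# and instance labels of shape (B, 1, H, W) with two fake objects (ids 1 and 2).
x = torch.rand(1, 3, 512, 512) * 255
y = torch.zeros(1, 1, 512, 512)
y[0, 0, 100:200, 100:200] = 1
y[0, 0, 300:400, 250:350] = 2

convert_inputs = ConvertToSamInputs(transform=ResizeLongestSide(1024), box_distortion_factor=0.025)
batched_inputs, sampled_ids = convert_inputs(x, y, n_pos=1, n_neg=0, get_boxes=True)
# Each entry of "batched_inputs" is a dict with "image", "original_size" and the requested
# prompt fields ("boxes", "point_coords", "point_labels").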
ConvertToSamInputs( transform: Optional[segment_anything.utils.transforms.ResizeLongestSide], dilation_strength: int = 10, box_distortion_factor: Optional[float] = None)
dilation_strength
transform
box_distortion_factor
class ConvertToSemanticSamInputs:

Convert outputs of the data loader to the batched inputs expected by the Segment Anything model for semantic segmentation.
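
A small sketch of the expected call pattern; the tensor shapes are illustrative:

import torch
from micro_sam.training.util import ConvertToSemanticSamInputs

convert_inputs = ConvertToSemanticSamInputs()
x = torch.rand(2, 3, 512, 512)               # images (B, C, H, W)
y = torch.randint(0, 5, (2, 1, 512, 512))    # semantic labels (B, 1, H, W)
batched_inputs = convert_inputs(x, y)
# One dict per image, containing only "image" and "original_size";
# note that the labels are not added to the batched inputs here.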

def normalize_to_8bit(raw):
class ResizeRawTrafo:
ResizeRawTrafo(desired_shape, do_rescaling=False, padding='constant')
desired_shape
padding
do_rescaling
class ResizeLabelTrafo:
ResizeLabelTrafo(desired_shape, padding='constant', min_size=0)
desired_shape
padding
min_size
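
A minimal sketch of pairing the two transforms as raw and label transforms for finetuning data; the target shape, the dummy arrays, and the min_size value are illustrative assumptions:

import numpy as np
from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo

raw_transform = ResizeRawTrafo(desired_shape=(512, 512), do_rescaling=False)
label_transform = ResizeLabelTrafo(desired_shape=(512, 512), min_size=25)

# Dummy 2D image and instance labels that are smaller than the desired shape.
raw = (np.random.rand(500, 480) * 255).astype("float32")
labels = np.zeros((500, 480), dtype="int64")
labels[50:150, 60:160] = 1

padded_raw = raw_transform(raw)          # padded to (512, 512) with the chosen padding mode
padded_labels = label_transform(labels)  # distance-transform channels padded to (4, 512, 512)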