micro_sam.training.util

  1import os
  2from math import ceil, floor
  3from functools import partial
  4from typing import Dict, List, Optional, Union, Tuple, Callable
  5
  6import numpy as np
  7
  8import torch
  9
 10from segment_anything.utils.transforms import ResizeLongestSide
 11
 12from ..prompt_generators import PointAndBoxPromptGenerator
 13from ..util import (
 14    get_centers_and_bounding_boxes, get_sam_model, get_device,
 15    segmentation_to_one_hot, _DEFAULT_MODEL,
 16)
 17from .. import models as custom_models
 18from .trainable_sam import TrainableSAM
 19
 20from torch_em.transform.label import PerObjectDistanceTransform
 21from torch_em.transform.raw import normalize_percentile, normalize
 22from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb
 23
 24
 25def identity(x):
 26    """Identity transformation.
 27
 28    This is a helper function to skip data normalization when finetuning SAM.
 29    Data normalization is performed within the model and should thus be skipped as
 30    a preprocessing step in training.
 31    """
 32    return x
 33
 34
 35def require_8bit(x):
 36    """Transformation to require 8bit input data range (0-255).
 37    """
 38    if x.max() < 1:
 39        x = x * 255
 40    return x
 41
 42
 43def _raw_transform(image: np.ndarray, raw_trafo: Callable) -> np.ndarray:
 44    return raw_trafo(image) * 255
 45
 46
 47def _normalize_percentile(image: np.ndarray) -> np.ndarray:
 48    image = normalize_percentile(image)  # Use 1st and 99th percentile values for min-max normalization.
 49    image = np.clip(image, 0, 1)  # Clip the values to be in range [0, 1].
 50    return image
 51
 52
 53def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
 54    """Get the transformation function to normalize inputs.
 55
 56    Args:
 57        preprocess: The normalization scheme, either 'normalize_minmax' or 'normalize_percentile'.
 58            By default ('None'), inputs are only ensured to be in the 8-bit data range (0-255).
 59
 60    Returns:
 61        The transformation function.
 62    """
 63
 64    if preprocess is None:  # Ensures that inputs are 8-bit.
 65        return require_8bit
 66    else:
 67        if preprocess == "normalize_minmax":
 68            raw_trafo = normalize
 69        elif preprocess == "normalize_percentile":
 70            raw_trafo = _normalize_percentile
 71        else:
 72            raise ValueError(f"'{preprocess}' is not a supported preprocessing.")
 73
 74        return partial(_raw_transform, raw_trafo=raw_trafo)
 75
 76
 77def get_trainable_sam_model(
 78    model_type: str = _DEFAULT_MODEL,
 79    device: Optional[Union[str, torch.device]] = None,
 80    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 81    freeze: Optional[List[str]] = None,
 82    return_state: bool = False,
 83    peft_kwargs: Optional[Dict] = None,
 84    flexible_load_checkpoint: bool = False,
 85    **model_kwargs
 86) -> TrainableSAM:
 87    """Get the trainable sam model.
 88
 89    Args:
 90        model_type: The segment anything model that should be finetuned. The weights of this model
 91            will be used for initialization, unless a custom weight file is passed via `checkpoint_path`.
 92        device: The device to use for training.
 93        checkpoint_path: Path to a custom checkpoint from which to load the model weights.
 94        freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and
 95            `mask_decoder`. By default nothing is frozen and the full model is updated.
 96        return_state: Whether to return the full checkpoint state.
 97        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 98        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 99        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
100
101    Returns:
102        The trainable segment anything model.
103    """
104    # set the device here so that the correct one is passed to TrainableSAM below
105    device = get_device(device)
106    _, sam, state = get_sam_model(
107        model_type=model_type,
108        device=device,
109        checkpoint_path=checkpoint_path,
110        return_sam=True,
111        return_state=True,
112        flexible_load_checkpoint=flexible_load_checkpoint,
113        **model_kwargs
114    )
115
116    # NOTE: This is done here, and not in "get_sam_model", so that PEFT's layer-specific initialization is applied on top.
117    # Whether to use Parameter Efficient Finetuning (PEFT) methods to wrap around Segment Anything.
118    # Overwrites the SAM model by freezing the backbone and allowing PEFT methods.
119    if peft_kwargs and isinstance(peft_kwargs, dict):
120        if model_type[:5] == "vit_t":
121            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
122
123        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
124
125    # freeze components of the model if freeze was passed
126    # ideally we would want to add components in such a way that:
127    # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network
128    #   (e.g. adding encoder blocks to "image_encoder")
129    if freeze is not None:
130        for name, param in sam.named_parameters():
131            if not isinstance(freeze, list):
132                # a single component name was passed, so wrap it into a list
133                freeze = [freeze]
134
135            # freeze all components whose names are given in the list
136            for l_item in freeze:
137                # in case PEFT is switched on, we cannot freeze the image encoder
138                if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
139                    raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")
140
141                if name.startswith(f"{l_item}"):
142                    param.requires_grad = False
143
144    # convert to trainable sam
145    trainable_sam = TrainableSAM(sam)
146
147    if return_state:
148        return trainable_sam, state
149    return trainable_sam
150
151
152class ConvertToSamInputs:
153    """Convert outputs of data loader to the expected batched inputs of the SegmentAnything model.
154
155    Args:
156        transform: The transformation to resize the prompts. Should be the same transform used in the
157            model to resize the inputs. If `None` the prompts will not be resized.
158        dilation_strength: The dilation factor.
159            It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts
160            due to imprecise groundtruth masks.
161        box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
162    """
163    def __init__(
164        self,
165        transform: Optional[ResizeLongestSide],
166        dilation_strength: int = 10,
167        box_distortion_factor: Optional[float] = None,
168    ) -> None:
169        self.dilation_strength = dilation_strength
170        self.transform = identity if transform is None else transform
171        self.box_distortion_factor = box_distortion_factor
172
173    def _distort_boxes(self, bbox_coordinates, shape):
174        distorted_boxes = []
175        for bbox in bbox_coordinates:
176            # The bounding box is parametrized by y0, x0, y1, x1.
177            y0, x0, y1, x1 = bbox
178            ly, lx = y1 - y0, x1 - x0
179            y0 = int(round(max(0, y0 - np.random.uniform(0, self.box_distortion_factor) * ly)))
180            y1 = int(round(min(shape[0], y1 + np.random.uniform(0, self.box_distortion_factor) * ly)))
181            x0 = int(round(max(0, x0 - np.random.uniform(0, self.box_distortion_factor) * lx)))
182            x1 = int(round(min(shape[1], x1 + np.random.uniform(0, self.box_distortion_factor) * lx)))
183            distorted_boxes.append([y0, x0, y1, x1])
184        return distorted_boxes
185
186    def _get_prompt_lists(self, gt, n_samples, prompt_generator):
187        """Returns a list of "expected" prompts subjected to the random input attributes for prompting."""
188
189        _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p")
190
191        # get the segment ids
192        cell_ids = np.unique(gt)[1:]
193        if n_samples is None:  # n-samples is set to None, so we use all ids
194            sampled_cell_ids = cell_ids
195
196        else:  # n-samples is set, so we subsample the cell ids
197            sampled_cell_ids = np.random.choice(cell_ids, size=min(n_samples, len(cell_ids)), replace=False)
198            sampled_cell_ids = np.sort(sampled_cell_ids)
199
200        # only keep the bounding boxes for sampled cell ids
201        bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids]
202        if self.box_distortion_factor is not None:
203            bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])
204
205        # convert the gt to the one-hot-encoded masks for the sampled cell ids
206        object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)
207
208        # derive and return the prompts
209        point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
210        return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids
211
212    def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
213        """Convert the outputs of dataloader and prompt settings to the batch format expected by SAM.
214        """
215        # Use point prompting only if positive or negative point prompts were requested;
216        # otherwise point prompts are skipped (e.g. when training with box prompts only).
217        if n_pos == 0 and n_neg == 0:
218            get_points = False
219        else:
220            get_points = True
221
222        # Set up the prompt generator for the requested combination of point and box prompts.
223        prompt_generator = PointAndBoxPromptGenerator(
224            n_positive_points=n_pos,
225            n_negative_points=n_neg,
226            dilation_strength=self.dilation_strength,
227            get_box_prompts=get_boxes,
228            get_point_prompts=get_points
229        )
230
231        batched_inputs = []
232        batched_sampled_cell_ids_list = []
233
234        for image, gt in zip(x, y):
235            gt = gt.squeeze().numpy().astype(np.int64)
236            box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
237                gt, n_samples, prompt_generator,
238            )
239
240            # Sanity check: the number of sampled objects must match the number of prompts.
241            if get_boxes:
242                assert len(sampled_cell_ids) == len(box_prompts), f"{len(sampled_cell_ids)}, {len(box_prompts)}"
243
244            if get_points:
245                assert len(sampled_cell_ids) == len(point_prompts) == len(point_label_prompts), \
246                    f"{len(sampled_cell_ids)}, {len(point_prompts)}, {len(point_label_prompts)}"
247
248            batched_sampled_cell_ids_list.append(sampled_cell_ids)
249
250            batched_input = {"image": image, "original_size": image.shape[1:]}
251            if get_boxes:
252                batched_input["boxes"] = self.transform.apply_boxes_torch(
253                    box_prompts, original_size=gt.shape[-2:]
254                ) if self.transform is not None else box_prompts
255
256            if get_points:
257                batched_input["point_coords"] = self.transform.apply_coords_torch(
258                    point_prompts, original_size=gt.shape[-2:]
259                ) if self.transform is not None else point_prompts
260                batched_input["point_labels"] = point_label_prompts
261
262            batched_inputs.append(batched_input)
263
264        return batched_inputs, batched_sampled_cell_ids_list
265
266
267class ConvertToSemanticSamInputs:
268    """Convert outputs of data loader to the expected batched inputs of the Segment Anything model
269    for semantic segmentation.
270    """
271    def __call__(self, x, y):
272        """Convert the outputs of dataloader to the batched format of inputs expected by SAM.
273        """
274        batched_inputs = []
275        for image in x:
276            batched_input = {"image": image, "original_size": image.shape[-2:]}
277            batched_inputs.append(batched_input)
278
279        return batched_inputs
280
281
282#
283# Raw and Label Transformations for the Generalist and Specialist finetuning
284#
285
286
287def normalize_to_8bit(raw):
288    raw = normalize(raw) * 255
289    return raw
290
291
292class ResizeRawTrafo:
293    def __init__(
294        self,
295        desired_shape: Tuple[int, ...],
296        do_rescaling: bool = False,
297        valid_channels: Optional[Union[int, Tuple[int, ...]]] = None,
298        padding: str = "constant"
299    ):
300        self.desired_shape = desired_shape
301        self.do_rescaling = do_rescaling
302        self.valid_channels = valid_channels
303        self.padding = padding
304
305    def __call__(self, raw):
306        raw = to_rgb(raw)  # Ensure the image has 3 channels: a single channel is replicated to three.
307
308        if self.do_rescaling:
309            raw = normalize_percentile(raw, axis=self.valid_channels)
310            raw = normalize(raw)
311            raw = raw * 255
312
313        # Pad the inputs to the desired shape.
314        tmp_ddim = [desired - curr for desired, curr in zip(self.desired_shape, raw.shape)]
315        ddim = [(per_dim / 2) for per_dim in tmp_ddim]
316        pad_width = [(ceil(d), floor(d)) for d in ddim]
317        raw = np.pad(raw, pad_width=pad_width, mode=self.padding)
318
319        assert raw.shape == self.desired_shape
320        return raw
321
322
323class ResizeLabelTrafo:
324    def __init__(
325        self, desired_shape: Tuple[int, ...], min_size: int = 0, padding: str = "constant",
326    ):
327        self.desired_shape = desired_shape
328        self.min_size = min_size
329        self.padding = padding
330
331    def __call__(self, labels):
332        distance_trafo = PerObjectDistanceTransform(
333            distances=True,
334            boundary_distances=True,
335            directed_distances=False,
336            foreground=True,
337            instances=True,
338            min_size=self.min_size
339        )
340        labels = distance_trafo(labels)
341
342        # The distance transform yields labels of shape (4, H, W); compute the padding for H and W (assumes a square desired shape).
343        tmp_ddim = (self.desired_shape[0] - labels.shape[1], self.desired_shape[0] - labels.shape[2])
344        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
345        labels = np.pad(
346            labels,
347            pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
348            mode=self.padding
349        )
350        assert labels.shape[1:] == self.desired_shape, labels.shape
351        return labels
def identity(x):

Identity transformation.

This is a helper function to skip data normalization when finetuning SAM. Data normalization is performed within the model and should thus be skipped as a preprocessing step in training.

def require_8bit(x):

Transformation to require 8bit input data range (0-255).

def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:

Get the transformation function to normalize inputs.

Arguments:
  • preprocess: The normalization scheme, either 'normalize_minmax' or 'normalize_percentile'. By default ('None'), inputs are only ensured to be in the 8-bit data range (0-255).
Returns:
  The transformation function.
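
As a quick illustration (not part of the original documentation), the returned callable can be applied directly to a numpy image, or passed as the raw transform of a data loader; the array name and shape below are placeholders:

    import numpy as np
    from micro_sam.training.util import get_raw_transform

    # Percentile-based normalization, rescaled to the 8-bit range expected by SAM.
    raw_transform = get_raw_transform("normalize_percentile")

    image = np.random.rand(512, 512).astype("float32")  # placeholder input image
    image_8bit = raw_transform(image)  # values now lie in [0, 255]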

def get_trainable_sam_model( model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, freeze: Optional[List[str]] = None, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs) -> micro_sam.training.trainable_sam.TrainableSAM:

Get the trainable sam model.

Arguments:
  • model_type: The segment anything model that should be finetuned. The weights of this model will be used for initialization, unless a custom weight file is passed via checkpoint_path.
  • device: The device to use for training.
  • checkpoint_path: Path to a custom checkpoint from which to load the model weights.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
  • return_state: Whether to return the full checkpoint state.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
  • model_kwargs: Additional keyword arguments for the util.get_sam_model.
Returns:
  The trainable segment anything model.
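
A minimal sketch (not from the original docs) of instantiating a partially frozen model for finetuning; the model type, frozen parts and learning rate are arbitrary placeholders:

    import torch
    from micro_sam.training.util import get_trainable_sam_model

    # Load a ViT-B SAM model and freeze everything except the mask decoder.
    model = get_trainable_sam_model(model_type="vit_b", freeze=["image_encoder", "prompt_encoder"])

    # Only the parameters that are still trainable are passed to the optimizer.
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-5)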

class ConvertToSamInputs:

Convert outputs of data loader to the expected batched inputs of the SegmentAnything model.

Arguments:
  • transform: The transformation to resize the prompts. Should be the same transform used in the model to resize the inputs. If None the prompts will not be resized.
  • dilation_strength: The dilation factor. It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts due to imprecise groundtruth masks.
  • box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
ConvertToSamInputs( transform: Optional[segment_anything.utils.transforms.ResizeLongestSide], dilation_strength: int = 10, box_distortion_factor: Optional[float] = None)
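
A sketch of how this converter is typically used in a training step (not taken from the original docs); the batch tensors `x` and `y` below are placeholders standing in for an image batch and the matching instance label batch from a data loader:

    import torch
    from segment_anything.utils.transforms import ResizeLongestSide
    from micro_sam.training.util import ConvertToSamInputs

    # Resize prompts with the same transform SAM applies to its inputs (1024 is SAM's input size).
    convert_inputs = ConvertToSamInputs(transform=ResizeLongestSide(1024), box_distortion_factor=0.025)

    # Placeholder batch: one 3-channel image and instance labels with two objects.
    x = torch.zeros(1, 3, 512, 512)
    y = torch.zeros(1, 1, 512, 512, dtype=torch.int64)
    y[0, 0, 100:200, 100:200] = 1
    y[0, 0, 300:400, 250:350] = 2

    # One positive point per object, no negative points, plus box prompts; sample at most 8 objects per image.
    batched_inputs, sampled_ids = convert_inputs(x, y, n_pos=1, n_neg=0, get_boxes=True, n_samples=8)
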
class ConvertToSemanticSamInputs:

Convert outputs of data loader to the expected batched inputs of the Segment Anything model for semantic segmentation.
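
For semantic segmentation the conversion only wraps each image together with its spatial size; a short sketch with placeholder batches:

    import torch
    from micro_sam.training.util import ConvertToSemanticSamInputs

    convert_inputs = ConvertToSemanticSamInputs()
    x = torch.zeros(2, 3, 512, 512)  # placeholder image batch
    y = torch.zeros(2, 1, 512, 512)  # placeholder semantic labels (not used by the conversion itself)
    batched_inputs = convert_inputs(x, y)  # each entry holds "image" and "original_size"; no prompts are added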

def normalize_to_8bit(raw):

Normalize the raw data to the 8-bit value range [0, 255] via min-max normalization.

class ResizeRawTrafo:

Pad raw inputs to a fixed spatial shape, ensuring three channels and optionally rescaling intensities to the 8-bit range [0, 255].

ResizeRawTrafo( desired_shape: Tuple[int, ...], do_rescaling: bool = False, valid_channels: Union[int, Tuple[int, ...], NoneType] = None, padding: str = 'constant')
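
A sketch of padding raw data to a fixed shape (not from the original docs); the desired shape is a placeholder and includes the channel axis, since inputs are converted to three channels first:

    import numpy as np
    from micro_sam.training.util import ResizeRawTrafo

    raw_trafo = ResizeRawTrafo(desired_shape=(3, 512, 512), do_rescaling=True)

    image = np.random.rand(460, 500).astype("float32")  # placeholder single-channel image
    padded = raw_trafo(image)  # converted to 3 channels, rescaled to [0, 255] and padded to (3, 512, 512)
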
class ResizeLabelTrafo:

Compute per-object distance transforms from instance labels and pad the result to a fixed spatial shape.

ResizeLabelTrafo( desired_shape: Tuple[int, ...], min_size: int = 0, padding: str = 'constant')
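
The label transform is the counterpart for instance labels; a sketch with placeholder shapes and parameters (the output channels are the foreground, distance and instance channels computed by `PerObjectDistanceTransform`):

    import numpy as np
    from micro_sam.training.util import ResizeLabelTrafo

    label_trafo = ResizeLabelTrafo(desired_shape=(512, 512), min_size=25)

    instance_labels = np.zeros((460, 500), dtype="uint16")  # placeholder instance segmentation
    instance_labels[100:200, 150:250] = 1  # a single object
    targets = label_trafo(instance_labels)  # shape (4, 512, 512), padded to the desired spatial shape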