micro_sam.training.util
import os
from math import ceil, floor
from functools import partial
from typing import Dict, List, Optional, Union, Tuple, Callable

import numpy as np

import torch

from segment_anything.utils.transforms import ResizeLongestSide

from ..prompt_generators import PointAndBoxPromptGenerator
from ..util import (
    get_centers_and_bounding_boxes, get_sam_model, get_device,
    segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .. import models as custom_models
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.transform.raw import normalize_percentile, normalize
from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb


def identity(x):
    """Identity transformation.

    This is a helper function to skip data normalization when finetuning SAM.
    Data normalization is performed within the model and should thus be skipped as
    a preprocessing step in training.
    """
    return x


def require_8bit(x):
    """Transformation to require 8bit input data range (0-255).
    """
    if x.max() < 1:
        x = x * 255
    return x


def _raw_transform(image: np.ndarray, raw_trafo: Callable) -> np.ndarray:
    return raw_trafo(image) * 255


def _normalize_percentile(image: np.ndarray) -> np.ndarray:
    image = normalize_percentile(image)  # Use 1st and 99th percentile values for min-max normalization.
    image = np.clip(image, 0, 1)  # Clip the values to be in range [0, 1].
    return image


def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
    """Transformation functions to normalize inputs.

    Args:
        preprocess: By default, the transformation function is set to 'None'.
            The user can choose from 'normalize_minmax' / 'normalize_percentile'.

    Returns:
        The transformation function.
    """

    if preprocess is None:  # Ensures that inputs are 8-bit.
        return require_8bit
    else:
        if preprocess == "normalize_minmax":
            raw_trafo = normalize
        elif preprocess == "normalize_percentile":
            raw_trafo = _normalize_percentile
        else:
            raise ValueError(f"'{preprocess}' is not a supported preprocessing.")

        return partial(_raw_transform, raw_trafo=raw_trafo)


def get_trainable_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs
) -> TrainableSAM:
    """Get the trainable sam model.

    Args:
        model_type: The segment anything model that should be finetuned. The weights of this model
            will be used for initialization, unless a custom weight file is passed via `checkpoint_path`.
        device: The device to use for training.
        checkpoint_path: Path to a custom checkpoint from which to load the model weights.
        freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and
            `mask_decoder`. By default nothing is frozen and the full model is updated.
        return_state: Whether to return the full checkpoint state.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.

    Returns:
        The trainable segment anything model.
    """
    # set the device here so that the correct one is passed to TrainableSAM below
    device = get_device(device)
    _, sam, state = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        return_state=True,
        flexible_load_checkpoint=flexible_load_checkpoint,
        **model_kwargs
    )

    # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top.
    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allow PEFT methods.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        if model_type[:5] == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # freeze components of the model if freeze was passed
    # ideally we would want to add components in such a way that:
    # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network
    #   (for e.g. encoder blocks to "image_encoder")
    if freeze is not None:
        for name, param in sam.named_parameters():
            if not isinstance(freeze, list):
                # we "freeze" only for one specific component when passed a "particular" part
                freeze = [freeze]

            # we would want to "freeze" all the components in the model if passed a list of parts
            for l_item in freeze:
                # in case PEFT is switched on, we cannot freeze the image encoder
                if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
                    raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

                if name.startswith(f"{l_item}"):
                    param.requires_grad = False

    # convert to trainable sam
    trainable_sam = TrainableSAM(sam)

    if return_state:
        return trainable_sam, state
    return trainable_sam


class ConvertToSamInputs:
    """Convert outputs of data loader to the expected batched inputs of the SegmentAnything model.

    Args:
        transform: The transformation to resize the prompts. Should be the same transform used in the
            model to resize the inputs. If `None` the prompts will not be resized.
        dilation_strength: The dilation factor.
            It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts
            due to imprecise groundtruth masks.
        box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
    """
    def __init__(
        self,
        transform: Optional[ResizeLongestSide],
        dilation_strength: int = 10,
        box_distortion_factor: Optional[float] = None,
    ) -> None:
        self.dilation_strength = dilation_strength
        self.transform = identity if transform is None else transform
        self.box_distortion_factor = box_distortion_factor

    def _distort_boxes(self, bbox_coordinates, shape):
        distorted_boxes = []
        for bbox in bbox_coordinates:
            # The bounding box is parametrized by y0, x0, y1, x1.
            y0, x0, y1, x1 = bbox
            ly, lx = y1 - y0, x1 - x0
            y0 = int(round(max(0, y0 - np.random.uniform(0, self.box_distortion_factor) * ly)))
            y1 = int(round(min(shape[0], y1 + np.random.uniform(0, self.box_distortion_factor) * ly)))
            x0 = int(round(max(0, x0 - np.random.uniform(0, self.box_distortion_factor) * lx)))
            x1 = int(round(min(shape[1], x1 + np.random.uniform(0, self.box_distortion_factor) * lx)))
            distorted_boxes.append([y0, x0, y1, x1])
        return distorted_boxes

    def _get_prompt_lists(self, gt, n_samples, prompt_generator):
        """Returns a list of "expected" prompts subjected to the random input attributes for prompting."""

        _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p")

        # get the segment ids
        cell_ids = np.unique(gt)[1:]
        if n_samples is None:  # n-samples is set to None, so we use all ids
            sampled_cell_ids = cell_ids

        else:  # n-samples is set, so we subsample the cell ids
            sampled_cell_ids = np.random.choice(cell_ids, size=min(n_samples, len(cell_ids)), replace=False)
            sampled_cell_ids = np.sort(sampled_cell_ids)

        # only keep the bounding boxes for sampled cell ids
        bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids]
        if self.box_distortion_factor is not None:
            bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])

        # convert the gt to the one-hot-encoded masks for the sampled cell ids
        object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)

        # derive and return the prompts
        point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
        return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids

    def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
        """Convert the outputs of dataloader and prompt settings to the batch format expected by SAM.
        """
        # condition to see if we get point prompts, then we (ofc) use point-prompting
        # else we don't use point prompting
        if n_pos == 0 and n_neg == 0:
            get_points = False
        else:
            get_points = True

        # keeping the solution open by checking for deterministic/dynamic choice of point prompts
        prompt_generator = PointAndBoxPromptGenerator(
            n_positive_points=n_pos,
            n_negative_points=n_neg,
            dilation_strength=self.dilation_strength,
            get_box_prompts=get_boxes,
            get_point_prompts=get_points
        )

        batched_inputs = []
        batched_sampled_cell_ids_list = []

        for image, gt in zip(x, y):
            gt = gt.squeeze().numpy().astype(np.int64)
            box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
                gt, n_samples, prompt_generator,
            )

            # check to be sure about the expected size of the no. of elements in different settings
            if get_boxes:
                assert len(sampled_cell_ids) == len(box_prompts), f"{len(sampled_cell_ids)}, {len(box_prompts)}"

            if get_points:
                assert len(sampled_cell_ids) == len(point_prompts) == len(point_label_prompts), \
                    f"{len(sampled_cell_ids)}, {len(point_prompts)}, {len(point_label_prompts)}"

            batched_sampled_cell_ids_list.append(sampled_cell_ids)

            batched_input = {"image": image, "original_size": image.shape[1:]}
            if get_boxes:
                batched_input["boxes"] = self.transform.apply_boxes_torch(
                    box_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else box_prompts

            if get_points:
                batched_input["point_coords"] = self.transform.apply_coords_torch(
                    point_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else point_prompts
                batched_input["point_labels"] = point_label_prompts

            batched_inputs.append(batched_input)

        return batched_inputs, batched_sampled_cell_ids_list


class ConvertToSemanticSamInputs:
    """Convert outputs of data loader to the expected batched inputs of the Segment Anything model
    for semantic segmentation.
    """
    def __call__(self, x, y):
        """Convert the outputs of dataloader to the batched format of inputs expected by SAM.
        """
        batched_inputs = []
        for image in x:
            batched_input = {"image": image, "original_size": image.shape[-2:]}
            batched_inputs.append(batched_input)

        return batched_inputs


#
# Raw and Label Transformations for the Generalist and Specialist finetuning
#


def normalize_to_8bit(raw):
    raw = normalize(raw) * 255
    return raw


class ResizeRawTrafo:
    def __init__(
        self,
        desired_shape: Tuple[int, ...],
        do_rescaling: bool = False,
        valid_channels: Optional[Union[int, Tuple[int, ...]]] = None,
        padding: str = "constant"
    ):
        self.desired_shape = desired_shape
        self.do_rescaling = do_rescaling
        self.valid_channels = valid_channels
        self.padding = padding

    def __call__(self, raw):
        raw = to_rgb(raw)  # Ensure all images are in 3-channels: triplicate one channel to three channels.

        if self.do_rescaling:
            raw = normalize_percentile(raw, axis=self.valid_channels)
            raw = normalize(raw)
            raw = raw * 255

        # Pad the inputs to the desired shape.
        tmp_ddim = [desired - curr for desired, curr in zip(self.desired_shape, raw.shape)]
        ddim = [(per_dim / 2) for per_dim in tmp_ddim]
        pad_width = [(ceil(d), floor(d)) for d in ddim]
        raw = np.pad(raw, pad_width=pad_width, mode=self.padding)

        assert raw.shape == self.desired_shape
        return raw


class ResizeLabelTrafo:
    def __init__(
        self, desired_shape: Tuple[int, ...], min_size: int = 0, padding: str = "constant",
    ):
        self.desired_shape = desired_shape
        self.min_size = min_size
        self.padding = padding

    def __call__(self, labels):
        distance_trafo = PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=self.min_size
        )
        labels = distance_trafo(labels)

        # choosing H and W from labels (4, H, W), from above dist trafo outputs
        tmp_ddim = (self.desired_shape[0] - labels.shape[1], self.desired_shape[0] - labels.shape[2])
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
        labels = np.pad(
            labels,
            pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
            mode=self.padding
        )
        assert labels.shape[1:] == self.desired_shape, labels.shape
        return labels
def identity(x):
Identity transformation.
This is a helper function to skip data normalization when finetuning SAM. Data normalization is performed within the model and should thus be skipped as a preprocessing step in training.
def require_8bit(x):
Transformation to require 8bit input data range (0-255).
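A quick sketch of its effect, using the function defined above: images that are already normalized to [0, 1] are rescaled, while data that is already in a larger value range passes through unchanged.

    import numpy as np

    from micro_sam.training.util import require_8bit

    float_image = np.random.rand(256, 256).astype("float32")  # values in [0, 1)
    print(require_8bit(float_image).max())  # rescaled towards the 0-255 range

    uint8_image = np.random.randint(0, 256, size=(256, 256), dtype="uint8")
    print(require_8bit(uint8_image).max())  # already 8-bit, returned unchanged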
def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
Transformation functions to normalize inputs.
Arguments:
- preprocess: The normalization to apply. By default this is set to 'None', in which case the inputs are only required to be in the 8-bit range (0-255). The user can choose from 'normalize_minmax' / 'normalize_percentile'.
Returns:
The transformation function.
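A brief usage sketch, assuming a 2D numpy image: with 'normalize_percentile' the intensities are normalized to the 1st/99th percentiles, clipped to [0, 1], and rescaled to 0-255.

    import numpy as np

    from micro_sam.training.util import get_raw_transform

    raw_transform = get_raw_transform("normalize_percentile")

    image = np.random.randint(0, 4096, size=(512, 512)).astype("float32")  # e.g. 12-bit microscopy data
    transformed = raw_transform(image)  # percentile-normalized and rescaled to the 0-255 range

    # The default (preprocess=None) returns `require_8bit`, which only ensures an 8-bit value range.
    default_transform = get_raw_transform()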
def get_trainable_sam_model(
    model_type: str = 'vit_b_lm',
    device: Union[str, torch.device, NoneType] = None,
    checkpoint_path: Union[os.PathLike, str, NoneType] = None,
    freeze: Optional[List[str]] = None,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs
) -> micro_sam.training.trainable_sam.TrainableSAM:
Get the trainable SAM model.

Arguments:
- model_type: The segment anything model that should be finetuned. The weights of this model will be used for initialization, unless a custom weight file is passed via `checkpoint_path`.
- device: The device to use for training.
- checkpoint_path: Path to a custom checkpoint from which to load the model weights.
- freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and `mask_decoder`. By default nothing is frozen and the full model is updated.
- return_state: Whether to return the full checkpoint state.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
- model_kwargs: Additional keyword arguments for `util.get_sam_model`.

Returns:
The trainable segment anything model.
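A short usage sketch: freeze the image and prompt encoders so that only the mask decoder is updated during finetuning. The commented `peft_kwargs` line is an assumption; only a 'rank' argument is referenced in this module, so check `micro_sam.models.peft_sam.PEFT_Sam` for the supported options.

    from micro_sam.training.util import get_trainable_sam_model

    # Finetune only the mask decoder by freezing the other model components.
    trainable_sam = get_trainable_sam_model(
        model_type="vit_b",
        freeze=["image_encoder", "prompt_encoder"],
    )

    # Alternatively, wrap the model with a parameter-efficient finetuning method
    # (illustrative settings; not supported for the 'vit_t' / mobile-sam models):
    # trainable_sam = get_trainable_sam_model(model_type="vit_b", peft_kwargs={"rank": 4})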
class ConvertToSamInputs:
Convert outputs of data loader to the expected batched inputs of the SegmentAnything model.

Arguments:
- transform: The transformation to resize the prompts. Should be the same transform used in the model to resize the inputs. If `None` the prompts will not be resized.
- dilation_strength: The dilation factor. It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts due to imprecise groundtruth masks.
- box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
ConvertToSamInputs(
    transform: Optional[segment_anything.utils.transforms.ResizeLongestSide],
    dilation_strength: int = 10,
    box_distortion_factor: Optional[float] = None
)
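A minimal sketch of how the converter is typically used in a training step. The dummy image and label tensors stand in for one batch from a data loader; the transform should match the one used by the model (here `ResizeLongestSide(1024)`, the standard SAM input transform), and the distortion factor is an illustrative value.

    import torch
    from segment_anything.utils.transforms import ResizeLongestSide

    from micro_sam.training.util import ConvertToSamInputs

    convert_inputs = ConvertToSamInputs(transform=ResizeLongestSide(1024), box_distortion_factor=0.025)

    # Dummy batch: one 3-channel image and a label image with two objects.
    image = torch.zeros(1, 3, 512, 512)
    labels = torch.zeros(1, 1, 512, 512, dtype=torch.int64)
    labels[0, 0, 100:200, 100:200] = 1
    labels[0, 0, 300:400, 250:350] = 2

    # Sample one positive point and a distorted box prompt per object.
    batched_inputs, sampled_ids = convert_inputs(image, labels, n_pos=1, n_neg=0, get_boxes=True)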
class ConvertToSemanticSamInputs:
Convert outputs of data loader to the expected batched inputs of the Segment Anything model for semantic segmentation.
def normalize_to_8bit(raw):

Normalize the input data to [0, 1] and rescale it to the 8-bit range (0-255).
class ResizeRawTrafo:
ResizeRawTrafo(
    desired_shape: Tuple[int, ...],
    do_rescaling: bool = False,
    valid_channels: Union[int, Tuple[int, ...], NoneType] = None,
    padding: str = 'constant'
)
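The class carries no docstring; from the implementation in the module source above it ensures a 3-channel image (via `to_rgb`), optionally rescales the intensities to the 8-bit range, and pads the input up to `desired_shape`, splitting the padding between both sides. A usage sketch, assuming `to_rgb` returns channel-first (3, H, W) images:

    import numpy as np

    from micro_sam.training.util import ResizeRawTrafo

    raw_trafo = ResizeRawTrafo(desired_shape=(3, 512, 512), do_rescaling=False)

    raw = np.random.randint(0, 255, size=(500, 489)).astype("float32")  # single-channel input
    padded = raw_trafo(raw)
    print(padded.shape)  # (3, 512, 512): channels triplicated, H/W padded to the desired shape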
class ResizeLabelTrafo:
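ResizeLabelTrafo is likewise undocumented; it converts an instance segmentation into the multi-channel distance-transform target used for training with the additional instance segmentation decoder (the shape comment in `__call__` indicates (4, H, W)) and pads it to `desired_shape`. A usage sketch with an illustrative `min_size`:

    import numpy as np

    from micro_sam.training.util import ResizeLabelTrafo

    label_trafo = ResizeLabelTrafo(desired_shape=(512, 512), min_size=25)

    labels = np.zeros((500, 489), dtype="int64")
    labels[100:200, 100:200] = 1
    labels[300:400, 250:350] = 2

    target = label_trafo(labels)
    print(target.shape)  # (4, 512, 512), per the shape comment in __call__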