micro_sam.training.util
import os
from math import ceil, floor
from functools import partial
from typing import Dict, List, Optional, Union, Tuple, Callable

import numpy as np

import torch

from segment_anything.utils.transforms import ResizeLongestSide

from ..prompt_generators import PointAndBoxPromptGenerator
from ..util import (
    get_centers_and_bounding_boxes, get_sam_model, get_device,
    segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .. import models as custom_models
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.transform.raw import normalize_percentile, normalize
from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb


def identity(x):
    """Identity transformation.

    This is a helper function to skip data normalization when finetuning SAM.
    Data normalization is performed within the model and should thus be skipped as
    a preprocessing step in training.
    """
    return x


def require_8bit(x):
    """Transformation to require 8bit input data range (0-255).
    """
    if x.max() < 1:
        x = x * 255
    return x


def _raw_transform(image: np.ndarray, raw_trafo: Callable) -> np.ndarray:
    return raw_trafo(image) * 255


def _normalize_percentile(image: np.ndarray) -> np.ndarray:
    image = normalize_percentile(image)  # Use 1st and 99th percentile values for min-max normalization.
    image = np.clip(image, 0, 1)  # Clip the values to be in range [0, 1].
    return image


def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
    """Transformation functions to normalize inputs.

    Args:
        preprocess: By default, the transformation function is set to 'None'.
            The user can choose from 'normalize_minmax' / 'normalize_percentile'.

    Returns:
        The transformation function.
    """
    if preprocess is None:  # Ensures that inputs are 8-bit.
        return require_8bit
    else:
        if preprocess == "normalize_minmax":
            raw_trafo = normalize
        elif preprocess == "normalize_percentile":
            raw_trafo = _normalize_percentile
        else:
            raise ValueError(f"'{preprocess}' is not a supported preprocessing.")

        return partial(_raw_transform, raw_trafo=raw_trafo)


def get_trainable_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs
) -> TrainableSAM:
    """Get the trainable SAM model.

    Args:
        model_type: The Segment Anything model that should be finetuned. The weights of this model
            will be used for initialization, unless a custom weight file is passed via `checkpoint_path`.
        device: The device to use for training. By default, automatically chooses the best available device.
        checkpoint_path: Path to a custom checkpoint from which to load the model weights.
        freeze: Specify parts of the model that should be frozen, namely: `image_encoder`, `prompt_encoder` and
            `mask_decoder`. By default nothing is frozen and the full model is updated.
        return_state: Whether to return the full checkpoint state. By default, set to 'False'.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
            By default, set to 'False'.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.

    Returns:
        The trainable segment anything model.
    """
    # set the device here so that the correct one is passed to TrainableSAM below
    device = get_device(device)
    _, sam, state = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        return_state=True,
        flexible_load_checkpoint=flexible_load_checkpoint,
        **model_kwargs
    )

    # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top.
    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allowing PEFT methods.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        if model_type[:5] == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # freeze components of the model if freeze was passed
    # ideally we would want to add components in such a way that:
    # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network
    #   (e.g. encoder blocks to "image_encoder")
    if freeze is not None:
        for name, param in sam.named_parameters():
            if not isinstance(freeze, list):
                # we "freeze" only one specific component when passed a "particular" part
                freeze = [freeze]

            # we want to "freeze" all the listed components in the model if passed a list of parts
            for l_item in freeze:
                # in case PEFT is switched on, we cannot freeze the image encoder
                if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
                    raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

                if name.startswith(f"{l_item}"):
                    param.requires_grad = False

    # convert to trainable sam
    trainable_sam = TrainableSAM(sam)

    if return_state:
        return trainable_sam, state
    return trainable_sam


class ConvertToSamInputs:
    """Convert outputs of data loader to the expected batched inputs of the Segment Anything model.

    Args:
        transform: The transformation to resize the prompts. Should be the same transform used in the
            model to resize the inputs. If `None` the prompts will not be resized.
        dilation_strength: The dilation factor.
            It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts
            due to imprecise groundtruth masks. By default, set to '10'.
        box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
    """
    def __init__(
        self,
        transform: Optional[ResizeLongestSide],
        dilation_strength: int = 10,
        box_distortion_factor: Optional[float] = None,
    ) -> None:
        self.dilation_strength = dilation_strength
        self.transform = identity if transform is None else transform
        self.box_distortion_factor = box_distortion_factor

    def _distort_boxes(self, bbox_coordinates, shape):
        distorted_boxes = []
        for bbox in bbox_coordinates:
            # The bounding box is parametrized by y0, x0, y1, x1.
            y0, x0, y1, x1 = bbox
            ly, lx = y1 - y0, x1 - x0
            y0 = int(round(max(0, y0 - np.random.uniform(0, self.box_distortion_factor) * ly)))
            y1 = int(round(min(shape[0], y1 + np.random.uniform(0, self.box_distortion_factor) * ly)))
            x0 = int(round(max(0, x0 - np.random.uniform(0, self.box_distortion_factor) * lx)))
            x1 = int(round(min(shape[1], x1 + np.random.uniform(0, self.box_distortion_factor) * lx)))
            distorted_boxes.append([y0, x0, y1, x1])
        return distorted_boxes

    def _get_prompt_lists(self, gt, n_samples, prompt_generator):
        """Returns a list of "expected" prompts subjected to the random input attributes for prompting."""
        _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p")

        # get the segment ids
        cell_ids = np.unique(gt)[1:]
        if n_samples is None:  # n_samples is set to None, so we use all ids
            sampled_cell_ids = cell_ids

        else:  # n_samples is set, so we subsample the cell ids
            sampled_cell_ids = np.random.choice(cell_ids, size=min(n_samples, len(cell_ids)), replace=False)
            sampled_cell_ids = np.sort(sampled_cell_ids)

        # only keep the bounding boxes for sampled cell ids
        bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids]
        if self.box_distortion_factor is not None:
            bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])

        # convert the gt to the one-hot-encoded masks for the sampled cell ids
        object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)

        # derive and return the prompts
        point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
        return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids

    def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
        """Convert the outputs of dataloader and prompt settings to the batch format expected by SAM.
        """
        # Condition to see if we get point prompts; if so we use point-prompting,
        # otherwise we don't use point prompting.
        if n_pos == 0 and n_neg == 0:
            get_points = False
        else:
            get_points = True

        # keeping the solution open by checking for deterministic/dynamic choice of point prompts
        prompt_generator = PointAndBoxPromptGenerator(
            n_positive_points=n_pos,
            n_negative_points=n_neg,
            dilation_strength=self.dilation_strength,
            get_box_prompts=get_boxes,
            get_point_prompts=get_points
        )

        batched_inputs = []
        batched_sampled_cell_ids_list = []

        for image, gt in zip(x, y):
            gt = gt.squeeze().numpy().astype(np.int64)
            box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
                gt, n_samples, prompt_generator,
            )

            # check to be sure about the expected number of elements in different settings
            if get_boxes:
                assert len(sampled_cell_ids) == len(box_prompts), f"{len(sampled_cell_ids)}, {len(box_prompts)}"

            if get_points:
                assert len(sampled_cell_ids) == len(point_prompts) == len(point_label_prompts), \
                    f"{len(sampled_cell_ids)}, {len(point_prompts)}, {len(point_label_prompts)}"

            batched_sampled_cell_ids_list.append(sampled_cell_ids)

            batched_input = {"image": image, "original_size": image.shape[1:]}
            if get_boxes:
                batched_input["boxes"] = self.transform.apply_boxes_torch(
                    box_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else box_prompts

            if get_points:
                batched_input["point_coords"] = self.transform.apply_coords_torch(
                    point_prompts, original_size=gt.shape[-2:]
                ) if self.transform is not None else point_prompts
                batched_input["point_labels"] = point_label_prompts

            batched_inputs.append(batched_input)

        return batched_inputs, batched_sampled_cell_ids_list


class ConvertToSemanticSamInputs:
    """Convert outputs of data loader to the expected batched inputs of the Segment Anything model
    for semantic segmentation.
    """
    def __call__(self, x, y):
        """Convert the outputs of dataloader to the batched format of inputs expected by SAM.
        """
        batched_inputs = []
        for image in x:
            batched_input = {"image": image, "original_size": image.shape[-2:]}
            batched_inputs.append(batched_input)

        return batched_inputs


#
# Raw and Label Transformations for the Generalist and Specialist finetuning
#


def normalize_to_8bit(raw):
    raw = normalize(raw) * 255
    return raw


class ResizeRawTrafo:
    def __init__(
        self,
        desired_shape: Tuple[int, ...],
        do_rescaling: bool = False,
        valid_channels: Optional[Union[int, Tuple[int, ...]]] = None,
        padding: str = "constant",
        ensure_rgb: bool = True,
    ):
        self.desired_shape = desired_shape
        self.do_rescaling = do_rescaling
        self.valid_channels = valid_channels
        self.padding = padding
        self.ensure_rgb = ensure_rgb

    def __call__(self, raw):
        if self.ensure_rgb:
            raw = to_rgb(raw)  # Ensure all images are in 3-channels: triplicate one channel to three channels.

        if self.do_rescaling:
            raw = normalize_percentile(raw, axis=self.valid_channels)
            raw = normalize(raw)
            raw = raw * 255

        # Pad the inputs to the desired shape.
        tmp_ddim = [desired - curr for desired, curr in zip(self.desired_shape, raw.shape)]
        ddim = [(per_dim / 2) for per_dim in tmp_ddim]
        pad_width = [(ceil(d), floor(d)) for d in ddim]
        raw = np.pad(raw, pad_width=pad_width, mode=self.padding)

        assert raw.shape == self.desired_shape
        return raw


class ResizeLabelTrafo:
    def __init__(
        self, desired_shape: Tuple[int, ...], min_size: int = 0, padding: str = "constant",
    ):
        self.desired_shape = desired_shape
        self.min_size = min_size
        self.padding = padding

    def __call__(self, labels):
        distance_trafo = PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=self.min_size
        )
        labels = distance_trafo(labels)

        # choosing H and W from labels (4, H, W), from the distance trafo outputs above
        tmp_ddim = (self.desired_shape[0] - labels.shape[1], self.desired_shape[0] - labels.shape[2])
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
        labels = np.pad(
            labels,
            pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
            mode=self.padding
        )
        assert labels.shape[1:] == self.desired_shape, labels.shape
        return labels
def identity(x):
Identity transformation.
This is a helper function to skip data normalization when finetuning SAM. Data normalization is performed within the model and should thus be skipped as a preprocessing step in training.
def require_8bit(x):
Transformation to require 8bit input data range (0-255).
def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
Transformation functions to normalize inputs.
Arguments:
- preprocess: By default, the transformation function is set to 'None'. The user can choose from 'normalize_minmax' / 'normalize_percentile'.
Returns:
The transformation function.
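For example, a percentile-based normalization can be requested and applied to raw data before training. A minimal sketch; the image is a random stand-in:

import numpy as np
from micro_sam.training.util import get_raw_transform

raw_transform = get_raw_transform("normalize_percentile")

# Stand-in for a microscopy image with an arbitrary intensity range.
image = np.random.rand(512, 512).astype("float32") * 4000
image = raw_transform(image)  # normalized with the 1st/99th percentiles, clipped and rescaled to [0, 255]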
def get_trainable_sam_model(model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, freeze: Optional[List[str]] = None, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs) -> micro_sam.training.trainable_sam.TrainableSAM:
Get the trainable SAM model.

Arguments:
- model_type: The Segment Anything model that should be finetuned. The weights of this model will be used for initialization, unless a custom weight file is passed via checkpoint_path.
- device: The device to use for training. By default, automatically chooses the best available device.
- checkpoint_path: Path to a custom checkpoint from which to load the model weights.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- return_state: Whether to return the full checkpoint state. By default, set to 'False'.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
- model_kwargs: Additional keyword arguments for util.get_sam_model.

Returns:
The trainable Segment Anything model.
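As an illustration, a trainable model can be set up with parts of the network frozen. A minimal sketch; the model type and frozen component are example choices, not defaults:

from micro_sam.training.util import get_trainable_sam_model

# Finetune only the prompt encoder and mask decoder by freezing the image encoder.
model = get_trainable_sam_model(model_type="vit_b", freeze=["image_encoder"])

# Frozen parameters no longer require gradients.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters:", n_trainable)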
class ConvertToSamInputs:
Convert outputs of data loader to the expected batched inputs of the Segment Anything model.
Arguments:
- transform: The transformation to resize the prompts. Should be the same transform used in the model to resize the inputs. If None the prompts will not be resized.
- dilation_strength: The dilation factor. It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts due to imprecise groundtruth masks. By default, set to '10'.
- box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks.
ConvertToSamInputs(transform: Optional[segment_anything.utils.transforms.ResizeLongestSide], dilation_strength: int = 10, box_distortion_factor: Optional[float] = None)
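The converter is typically called on a batch from the data loader inside the training loop. A minimal sketch with toy tensors; the image size, object layout and the target length of 1024 for the resize transform are illustrative assumptions:

import torch
from segment_anything.utils.transforms import ResizeLongestSide
from micro_sam.training.util import ConvertToSamInputs

convert_inputs = ConvertToSamInputs(transform=ResizeLongestSide(1024), dilation_strength=10)

# Images of shape (B, C, H, W) and instance labels of shape (B, 1, H, W) with two toy objects.
x = torch.rand(1, 3, 512, 512)
y = torch.zeros(1, 1, 512, 512, dtype=torch.int64)
y[0, 0, 100:200, 100:220] = 1
y[0, 0, 300:400, 50:150] = 2

# One positive point per object plus box prompts.
batched_inputs, sampled_ids = convert_inputs(x, y, n_pos=1, n_neg=0, get_boxes=True)
# Each entry contains "image", "original_size", "boxes", "point_coords" and "point_labels".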
class ConvertToSemanticSamInputs:
Convert outputs of data loader to the expected batched inputs of the Segment Anything model for semantic segmentation.
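A minimal sketch of the conversion (the tensors are stand-ins): it simply wraps each image into the list-of-dictionaries format expected by SAM.

import torch
from micro_sam.training.util import ConvertToSemanticSamInputs

convert_inputs = ConvertToSemanticSamInputs()
x = torch.rand(2, 3, 512, 512)                      # images
y = torch.zeros(2, 1, 512, 512, dtype=torch.long)   # semantic labels (not used for the conversion itself)
batched_inputs = convert_inputs(x, y)  # [{"image": ..., "original_size": ...}, ...]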
def normalize_to_8bit(raw):
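Per the module source above, this helper min-max normalizes the input and rescales it to the 8-bit range. A minimal usage sketch:

import numpy as np
from micro_sam.training.util import normalize_to_8bit

raw = np.random.randint(0, 65535, size=(256, 256)).astype("float32")
raw_8bit = normalize_to_8bit(raw)  # min-max normalized, then scaled to [0, 255]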
class ResizeRawTrafo:
ResizeRawTrafo(desired_shape: Tuple[int, ...], do_rescaling: bool = False, valid_channels: Union[int, Tuple[int, ...], NoneType] = None, padding: str = 'constant', ensure_rgb: bool = True)
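A minimal sketch of padding a raw crop to a fixed shape; the shapes are illustrative and ensure_rgb is switched off here so the single-channel input is kept as-is:

import numpy as np
from micro_sam.training.util import ResizeRawTrafo

raw_trafo = ResizeRawTrafo(desired_shape=(512, 512), do_rescaling=False, ensure_rgb=False)

raw = np.random.rand(256, 384).astype("float32")
padded = raw_trafo(raw)  # zero-padded ("constant" mode) to the desired shape
assert padded.shape == (512, 512)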
class ResizeLabelTrafo:
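A minimal sketch of the corresponding label transform (toy labels and shapes for illustration): instance labels are converted to distance-transform targets and padded to the desired spatial shape.

import numpy as np
from micro_sam.training.util import ResizeLabelTrafo

labels = np.zeros((256, 384), dtype="int64")
labels[50:100, 60:120] = 1     # two toy objects
labels[150:200, 200:300] = 2

label_trafo = ResizeLabelTrafo(desired_shape=(512, 512), min_size=25)
targets = label_trafo(labels)
# Channel layout follows the (4, H, W) comment in the source: foreground, center distances,
# boundary distances and the instance labels; the spatial dimensions are padded to (512, 512).
assert targets.shape[1:] == (512, 512)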