micro_sam.bioimageio.predictor_adaptor
import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

# Prefer the mobile_sam registry (needed for 'vit_t'); fall back to segment_anything if it is not installed.
try:
    # Avoid import warnings from mobile_sam
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry
except ImportError:
    from segment_anything import sam_model_registry


class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support the 'mobile_sam' package has to be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        # No checkpoint is passed here; weights are expected to be loaded via `load_state_dict`.
        self.sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(self.sam_model)

    def load_state_dict(self, state, **kwargs):
        """Load a state dict into the wrapped SAM model.

        This delegates to `self.sam.model.load_state_dict`, so the keys of `state` must be those of
        the SAM model itself (e.g. from a SAM checkpoint).

        Args:
            state: The state dict to load.
            kwargs: Further keyword arguments for `torch.nn.Module.load_state_dict`, e.g. `strict`.

        Returns:
            The result of `load_state_dict` of the wrapped SAM model.
        """
        return self.sam.model.load_state_dict(state, **kwargs)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Segment objects in `image` from box, point and/or mask prompts.

        The image embeddings are cached in the wrapped predictor. If an image has already been set
        and no `embeddings` are passed, the previously computed embeddings are reused. Passing
        `embeddings` replaces the cached features without running the image encoder.

        Args:
            image: torch inputs of dimensions B x C x H x W. Only B = 1 is supported.
            box_prompts: box coordinates of dimensions B x OBJECTS x 4
            point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
            point_labels: point labels of dimension B x OBJECTS x POINTS.
                Must be given whenever `point_prompts` is given.
            mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
                (B x OBJECTS x 1 x 256 x 256 is accepted as well).
            embeddings: precomputed image embeddings B x 256 x 64 x 64

        Returns:
            The segmentation masks, as uint8.
            The scores for prediction quality.
            The computed image embeddings.

        Raises:
            ValueError: If the batch size is not 1, or if `point_prompts` is given without `point_labels`.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"PredictorAdaptor only supports a batch size of 1, got {batch_size}.")
        # Validate the prompts before touching the predictor, so that a bad call leaves its state unchanged.
        if point_prompts is not None and point_labels is None:
            raise ValueError("point_labels must be passed together with point_prompts.")

        if embeddings is not None:
            # The embeddings are passed, so we set them. Only the size of the resized model input is
            # needed, so compute it directly instead of interpolating the whole image.
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.get_preprocess_shape(
                self.sam.orig_h, self.sam.orig_w, self.sam.transform.target_length
            )
            self.sam.is_image_set = True

        elif not self.sam.is_image_set:
            # We don't have image embeddings set and they were not passed: run the image encoder.
            # Cast to float for MPS compatibility: F.interpolate with antialias=True
            # only supports floating-point dtypes on MPS (Apple Silicon).
            image_float = image if image.is_floating_point() else image.float()
            input_ = self.sam.transform.apply_image_torch(image_float)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        # Otherwise image embeddings are already set and none were passed: reuse them.
        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        # Map the prompts from original image coordinates to the resized model input and remove
        # the batch axis (B == 1); apply_boxes_torch already returns boxes flattened to N x 4.
        boxes = None
        if box_prompts is not None:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        point_coords = None
        if point_prompts is not None:
            point_coords = self.sam.transform.apply_coords_torch(
                point_prompts, original_size=self.sam.original_size
            )[0]
            point_labels = point_labels[0]

        mask_input = None
        if mask_prompts is not None:
            mask_input = mask_prompts[0]
            # SamPredictor.predict_torch expects masks of shape OBJECTS x 1 x 256 x 256,
            # so add the channel axis if the prompts come without it.
            if mask_input.ndim == 3:
                mask_input = mask_input[:, None]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure batch axis.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support the 'mobile_sam' package has to be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        # No checkpoint is passed here; weights are expected to be loaded via `load_state_dict`.
        self.sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(self.sam_model)

    def load_state_dict(self, state, **kwargs):
        """Load a state dict into the wrapped SAM model.

        This delegates to `self.sam.model.load_state_dict`, so the keys of `state` must be those of
        the SAM model itself (e.g. from a SAM checkpoint).

        Args:
            state: The state dict to load.
            kwargs: Further keyword arguments for `torch.nn.Module.load_state_dict`, e.g. `strict`.

        Returns:
            The result of `load_state_dict` of the wrapped SAM model.
        """
        return self.sam.model.load_state_dict(state, **kwargs)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Segment objects in `image` from box, point and/or mask prompts.

        The image embeddings are cached in the wrapped predictor. If an image has already been set
        and no `embeddings` are passed, the previously computed embeddings are reused. Passing
        `embeddings` replaces the cached features without running the image encoder.

        Args:
            image: torch inputs of dimensions B x C x H x W. Only B = 1 is supported.
            box_prompts: box coordinates of dimensions B x OBJECTS x 4
            point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
            point_labels: point labels of dimension B x OBJECTS x POINTS.
                Must be given whenever `point_prompts` is given.
            mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
                (B x OBJECTS x 1 x 256 x 256 is accepted as well).
            embeddings: precomputed image embeddings B x 256 x 64 x 64

        Returns:
            The segmentation masks, as uint8.
            The scores for prediction quality.
            The computed image embeddings.

        Raises:
            ValueError: If the batch size is not 1, or if `point_prompts` is given without `point_labels`.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"PredictorAdaptor only supports a batch size of 1, got {batch_size}.")
        # Validate the prompts before touching the predictor, so that a bad call leaves its state unchanged.
        if point_prompts is not None and point_labels is None:
            raise ValueError("point_labels must be passed together with point_prompts.")

        if embeddings is not None:
            # The embeddings are passed, so we set them. Only the size of the resized model input is
            # needed, so compute it directly instead of interpolating the whole image.
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.get_preprocess_shape(
                self.sam.orig_h, self.sam.orig_w, self.sam.transform.target_length
            )
            self.sam.is_image_set = True

        elif not self.sam.is_image_set:
            # We don't have image embeddings set and they were not passed: run the image encoder.
            # Cast to float for MPS compatibility: F.interpolate with antialias=True
            # only supports floating-point dtypes on MPS (Apple Silicon).
            image_float = image if image.is_floating_point() else image.float()
            input_ = self.sam.transform.apply_image_torch(image_float)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        # Otherwise image embeddings are already set and none were passed: reuse them.
        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        # Map the prompts from original image coordinates to the resized model input and remove
        # the batch axis (B == 1); apply_boxes_torch already returns boxes flattened to N x 4.
        boxes = None
        if box_prompts is not None:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        point_coords = None
        if point_prompts is not None:
            point_coords = self.sam.transform.apply_coords_torch(
                point_prompts, original_size=self.sam.original_size
            )[0]
            point_labels = point_labels[0]

        mask_input = None
        if mask_prompts is not None:
            mask_input = mask_prompts[0]
            # SamPredictor.predict_torch expects masks of shape OBJECTS x 1 x 256 x 256,
            # so add the channel axis if the prompts come without it.
            if mask_input.ndim == 3:
                mask_input = mask_input[:, None]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure batch axis.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
Wrapper around the SamPredictor.
This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts.
Arguments:
- model_type: The type of the model for the image encoder. Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. For 'vit_t' support the 'mobile_sam' package has to be installed.
def __init__(self, model_type: str) -> None:
    """Build the SAM model of the requested type and wrap it in a SamPredictor.

    Args:
        model_type: Key into `sam_model_registry`, e.g. 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
    """
    super().__init__()
    build_model = sam_model_registry[model_type]
    # The builder is called without a checkpoint; weights are expected to come via `load_state_dict`.
    self.sam_model = build_model()
    self.sam = SamPredictor(self.sam_model)
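Create the SAM model of the given `model_type` from `sam_model_registry` and wrap it in a `SamPredictor`. (Inherited `nn.Module` documentation: initialize internal Module state, shared by both nn.Module and ScriptModule.)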
def load_state_dict(self, state, **kwargs):
    """Load `state` into the wrapped SAM model instead of into this adaptor.

    The call is forwarded to `self.sam.model.load_state_dict`, so `state` is keyed like the SAM
    model itself (e.g. a SAM checkpoint). Extra keyword arguments such as `strict` are passed on.

    Returns:
        Whatever `load_state_dict` of the wrapped SAM model returns.
    """
    wrapped_model = self.sam.model
    return wrapped_model.load_state_dict(state, **kwargs)
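Load `state` into the wrapped SAM model (`self.sam.model`). This override delegates to `torch.nn.Module.load_state_dict` of the wrapped model; its first parameter is named `state` (not `state_dict`), and further keyword arguments such as `strict` and `assign` are passed through. The inherited PyTorch documentation follows: copy parameters and buffers from `state_dict` into this module and its descendants.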
If `strict` is `True`, then the keys of `state_dict` must exactly match the keys returned by this module's `torch.nn.Module.state_dict()` function.
If `assign` is `True`, the optimizer must be created after the call to `load_state_dict` unless `torch.__future__.get_swap_module_params_on_conversion()` is `True`.
Arguments:
- state_dict (dict): a dict containing parameters and persistent buffers.
- strict (bool, optional): whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `torch.nn.Module.state_dict()` function. Default: `True`
- assign (bool, optional): When set to `False`, the properties of the tensors in the current module are preserved, whereas setting it to `True` preserves properties of the Tensors in the state dict. The only exception is the `requires_grad` field of `torch.nn.Parameter`, for which the value from the module is preserved. Default: `False`
Returns:
`NamedTuple` with `missing_keys` and `unexpected_keys` fields:
- `missing_keys` is a list of str containing any keys that are expected by this module but missing from the provided `state_dict`.
- `unexpected_keys` is a list of str containing the keys that are not expected by this module but present in the provided `state_dict`.
Note:
If a parameter or buffer is registered as `None` and its corresponding key exists in `state_dict`, `load_state_dict()` will raise a `RuntimeError`.
@torch.no_grad()
def forward(
    self,
    image: torch.Tensor,
    box_prompts: Optional[torch.Tensor] = None,
    point_prompts: Optional[torch.Tensor] = None,
    point_labels: Optional[torch.Tensor] = None,
    mask_prompts: Optional[torch.Tensor] = None,
    embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Segment objects in `image` from box, point and/or mask prompts.

    The image embeddings are cached in the wrapped predictor. If an image has already been set
    and no `embeddings` are passed, the previously computed embeddings are reused. Passing
    `embeddings` replaces the cached features without running the image encoder.

    Args:
        image: torch inputs of dimensions B x C x H x W. Only B = 1 is supported.
        box_prompts: box coordinates of dimensions B x OBJECTS x 4
        point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
        point_labels: point labels of dimension B x OBJECTS x POINTS.
            Must be given whenever `point_prompts` is given.
        mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
            (B x OBJECTS x 1 x 256 x 256 is accepted as well).
        embeddings: precomputed image embeddings B x 256 x 64 x 64

    Returns:
        The segmentation masks, as uint8.
        The scores for prediction quality.
        The computed image embeddings.

    Raises:
        ValueError: If the batch size is not 1, or if `point_prompts` is given without `point_labels`.
    """
    batch_size = image.shape[0]
    if batch_size != 1:
        raise ValueError(f"PredictorAdaptor only supports a batch size of 1, got {batch_size}.")
    # Validate the prompts before touching the predictor, so that a bad call leaves its state unchanged.
    if point_prompts is not None and point_labels is None:
        raise ValueError("point_labels must be passed together with point_prompts.")

    if embeddings is not None:
        # The embeddings are passed, so we set them. Only the size of the resized model input is
        # needed, so compute it directly instead of interpolating the whole image.
        self.sam.features = embeddings
        self.sam.orig_h, self.sam.orig_w = image.shape[2:]
        self.sam.input_h, self.sam.input_w = self.sam.transform.get_preprocess_shape(
            self.sam.orig_h, self.sam.orig_w, self.sam.transform.target_length
        )
        self.sam.is_image_set = True

    elif not self.sam.is_image_set:
        # We don't have image embeddings set and they were not passed: run the image encoder.
        # Cast to float for MPS compatibility: F.interpolate with antialias=True
        # only supports floating-point dtypes on MPS (Apple Silicon).
        image_float = image if image.is_floating_point() else image.float()
        input_ = self.sam.transform.apply_image_torch(image_float)
        self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
        self.sam.orig_h, self.sam.orig_w = self.sam.original_size
        self.sam.input_h, self.sam.input_w = self.sam.input_size

    # Otherwise image embeddings are already set and none were passed: reuse them.
    assert self.sam.is_image_set, "The predictor has not yet been initialized."

    # Ensure input size and original size are set.
    self.sam.input_size = (self.sam.input_h, self.sam.input_w)
    self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

    # Map the prompts from original image coordinates to the resized model input and remove
    # the batch axis (B == 1); apply_boxes_torch already returns boxes flattened to N x 4.
    boxes = None
    if box_prompts is not None:
        boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

    point_coords = None
    if point_prompts is not None:
        point_coords = self.sam.transform.apply_coords_torch(
            point_prompts, original_size=self.sam.original_size
        )[0]
        point_labels = point_labels[0]

    mask_input = None
    if mask_prompts is not None:
        mask_input = mask_prompts[0]
        # SamPredictor.predict_torch expects masks of shape OBJECTS x 1 x 256 x 256,
        # so add the channel axis if the prompts come without it.
        if mask_input.ndim == 3:
            mask_input = mask_input[:, None]

    masks, scores, _ = self.sam.predict_torch(
        point_coords=point_coords,
        point_labels=point_labels,
        boxes=boxes,
        mask_input=mask_input,
        multimask_output=False,
    )

    assert masks.shape[2:] == image.shape[2:], \
        f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

    # Ensure batch axis.
    if masks.ndim == 4:
        masks = masks[None]
        assert scores.ndim == 2
        scores = scores[None]

    embeddings = self.sam.get_image_embedding()
    return masks.to(dtype=torch.uint8), scores, embeddings
Arguments:
- image: torch input of dimensions B x C x H x W (only B = 1 is supported; other batch sizes raise a ValueError)
- box_prompts: box coordinates of dimensions B x OBJECTS x 4
- point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
- point_labels: point labels of dimension B x OBJECTS x POINTS
- mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
- embeddings: precomputed image embeddings B x 256 x 64 x 64
Returns:
A tuple of the segmentation masks (as uint8), the scores for prediction quality, and the computed image embeddings.