micro_sam.bioimageio.predictor_adaptor
import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

try:
    # Avoid import warnings from mobile_sam
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry
except ImportError:
    from segment_anything import sam_model_registry


class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support the 'mobile_sam' package has to be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(sam_model)

    def load_state_dict(self, state):
        self.sam.model.load_state_dict(state)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            image: torch inputs of dimensions B x C x H x W
            box_prompts: box coordinates of dimensions B x OBJECTS x 4
            point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
            point_labels: point labels of dimension B x OBJECTS x POINTS
            mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
            embeddings: precomputed image embeddings B x 256 x 64 x 64

        Returns:
            The segmentation masks.
            The scores for prediction quality.
            The computed image embeddings.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"Expected a batch size of 1, got {batch_size}.")

        # We have image embeddings set and image embeddings were not passed.
        if self.sam.is_image_set and embeddings is None:
            pass  # do nothing

        # The embeddings are passed, so we set them.
        elif embeddings is not None:
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:]
            self.sam.is_image_set = True

        # We don't have image embeddings set and they were not passed.
        elif not self.sam.is_image_set:
            input_ = self.sam.transform.apply_image_torch(image)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        if box_prompts is None:
            boxes = None
        else:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        if point_prompts is None:
            point_coords = None
        else:
            assert point_labels is not None
            point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0]
            point_labels = point_labels[0]

        if mask_prompts is None:
            mask_input = None
        else:
            mask_input = mask_prompts[0]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure batch axis.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
class PredictorAdaptor(nn.Module):
Wrapper around the SamPredictor.
This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts.
Arguments:
- model_type: The type of the model for the image encoder. Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. For 'vit_t' support, the 'mobile_sam' package must be installed.
def __init__(self, model_type: str) -> None:
    super().__init__()
    sam_model = sam_model_registry[model_type]()
    self.sam = SamPredictor(sam_model)
Initializes the adaptor by building the SAM model for the given model type and wrapping it in a SamPredictor.
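For illustration, a minimal construction sketch, assuming segment_anything is installed (and mobile_sam for 'vit_t'). Note that this builds the SAM architecture with freshly initialized weights; a checkpoint still has to be loaded via load_state_dict.

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

# Build the adaptor around a ViT-B image encoder.
# At this point the weights are randomly initialized.
adaptor = PredictorAdaptor(model_type="vit_b")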
def load_state_dict(self, state):
    self.sam.model.load_state_dict(state)

Copy parameters and buffers from state_dict into this module and its descendants. This override forwards the state dict to the wrapped SAM model, i.e. it is equivalent to self.sam.model.load_state_dict(state).

If strict is True, the keys of state_dict must exactly match the keys returned by this module's torch.nn.Module.state_dict() function.

If assign is True, the optimizer must be created after the call to load_state_dict, unless torch.__future__.get_swap_module_params_on_conversion() is True.

Arguments:
- state_dict (dict): a dict containing parameters and persistent buffers.
- strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict() function. Default: True
- assign (bool, optional): When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves the properties of the tensors in the state dict. The only exception is the requires_grad field of torch.nn.Parameter, for which the value from the module is preserved. Default: False

Returns:
NamedTuple with missing_keys and unexpected_keys fields:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Note:
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
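A short usage sketch: the checkpoint path below is hypothetical, and the state dict is expected to match the underlying SAM model (not the adaptor itself), since the override forwards it to the wrapped model.

import torch

# Hypothetical path to a SAM ViT-B checkpoint file.
state = torch.load("sam_vit_b.pth", map_location="cpu")
adaptor.load_state_dict(state)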
@torch.no_grad()
def forward(
    self,
    image: torch.Tensor,
    box_prompts: Optional[torch.Tensor] = None,
    point_prompts: Optional[torch.Tensor] = None,
    point_labels: Optional[torch.Tensor] = None,
    mask_prompts: Optional[torch.Tensor] = None,
    embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Arguments:
- image: input image tensor of dimensions B x C x H x W
- box_prompts: box coordinates of dimensions B x OBJECTS x 4
- point_prompts: point coordinates of dimensions B x OBJECTS x POINTS x 2
- point_labels: point labels of dimensions B x OBJECTS x POINTS
- mask_prompts: mask prompts of dimensions B x OBJECTS x 256 x 256
- embeddings: precomputed image embeddings of dimensions B x 256 x 64 x 64
Returns:
- The segmentation masks.
- The scores for prediction quality.
- The computed image embeddings.
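To make the expected shapes concrete, here is a sketch of two forward calls on a random stand-in image, using the adaptor from the examples above. The first call segments from a box prompt and computes the image embeddings; the second call passes the returned embeddings back in, so the image encoder does not run again.

import torch

image = torch.rand(1, 3, 512, 512) * 255  # B x C x H x W, float
boxes = torch.tensor([[[100.0, 120.0, 260.0, 280.0]]])  # B x OBJECTS x 4

masks, scores, embeddings = adaptor(image=image, box_prompts=boxes)
# masks: (1, 1, 1, 512, 512), dtype uint8
# scores: (1, 1, 1)
# embeddings: (1, 256, 64, 64)

# Reuse the precomputed embeddings for a point prompt on the same image.
points = torch.tensor([[[[180.0, 200.0]]]])  # B x OBJECTS x POINTS x 2
labels = torch.ones(1, 1, 1)                 # B x OBJECTS x POINTS
masks2, scores2, _ = adaptor(
    image=image,
    point_prompts=points,
    point_labels=labels,
    embeddings=embeddings,
)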