micro_sam.bioimageio.predictor_adaptor
import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

try:
    # Avoid import warnings from mobile_sam
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry
except ImportError:
    from segment_anything import sam_model_registry


class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support, the 'mobile_sam' package must be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(sam_model)

    def load_state_dict(self, state):
        self.sam.model.load_state_dict(state)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            image: input image of dimensions B x C x H x W
            box_prompts: box coordinates of dimensions B x OBJECTS x 4
            point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
            point_labels: point labels of dimension B x OBJECTS x POINTS
            mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
            embeddings: precomputed image embeddings B x 256 x 64 x 64

        Returns:
            The segmentation masks, the mask quality scores and the image embeddings.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"Only batch size 1 is supported, got {batch_size}.")

        # We have image embeddings set and image embeddings were not passed.
        if self.sam.is_image_set and embeddings is None:
            pass  # do nothing

        # The embeddings are passed, so we set them.
        elif embeddings is not None:
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:]
            self.sam.is_image_set = True

        # We don't have image embeddings set and they were not passed.
        elif not self.sam.is_image_set:
            input_ = self.sam.transform.apply_image_torch(image)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        if box_prompts is None:
            boxes = None
        else:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        if point_prompts is None:
            point_coords = None
        else:
            assert point_labels is not None
            point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0]
            point_labels = point_labels[0]

        if mask_prompts is None:
            mask_input = None
        else:
            mask_input = mask_prompts[0]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure batch axis.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
class PredictorAdaptor(nn.Module):
Wrapper around the SamPredictor.
This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts.
Arguments:
- model_type: The type of the model for the image encoder. Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. For 'vit_t' support, the 'mobile_sam' package must be installed.
def __init__(self, model_type: str) -> None:
    super().__init__()
    sam_model = sam_model_registry[model_type]()
    self.sam = SamPredictor(sam_model)
Creates the SAM model for the given model_type and wraps it in a SamPredictor.
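A minimal construction sketch; the checkpoint file name below is a placeholder, not a path shipped with micro_sam:

import torch

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

# Instantiate the adaptor; this builds a vit_b SAM model with randomly initialized weights.
adaptor = PredictorAdaptor(model_type="vit_b")

# Load pretrained SAM weights into the wrapped model (placeholder file name).
state = torch.load("sam_vit_b_weights.pt", map_location="cpu")
adaptor.load_state_dict(state)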
def load_state_dict(self, state):
    self.sam.model.load_state_dict(state)

Copy parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's torch.nn.Module.state_dict() function. If assign is True, the optimizer must be created after the call to load_state_dict unless torch.__future__.get_swap_module_params_on_conversion() is True.

Arguments:
- state_dict (dict): a dict containing parameters and persistent buffers.
- strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict() function. Default: True
- assign (bool, optional): When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of torch.nn.Parameters, for which the value from the module is preserved. Default: False

Returns:
NamedTuple with missing_keys and unexpected_keys fields:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Note:
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
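Note that PredictorAdaptor overrides this method with the simpler signature shown above, forwarding the state dict to the wrapped SAM model; the strict and assign options belong to the underlying torch.nn.Module.load_state_dict. A small sketch of those torch semantics on a toy module:

import torch
from torch import nn

module = nn.Linear(4, 2)
state = module.state_dict()
state["extra_key"] = torch.zeros(1)  # a key the module does not expect

# With strict=False the mismatch is reported instead of raising a RuntimeError.
result = module.load_state_dict(state, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['extra_key']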
@torch.no_grad()
def forward(
    self,
    image: torch.Tensor,
    box_prompts: Optional[torch.Tensor] = None,
    point_prompts: Optional[torch.Tensor] = None,
    point_labels: Optional[torch.Tensor] = None,
    mask_prompts: Optional[torch.Tensor] = None,
    embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Arguments:
- image: input image tensor of shape B x C x H x W; only batch size B = 1 is supported
- box_prompts: box coordinates of shape B x OBJECTS x 4
- point_prompts: point coordinates of shape B x OBJECTS x POINTS x 2
- point_labels: point labels of shape B x OBJECTS x POINTS
- mask_prompts: mask prompts of shape B x OBJECTS x 256 x 256
- embeddings: precomputed image embeddings of shape B x 256 x 64 x 64

Returns:
The binary segmentation masks as a uint8 tensor of shape B x OBJECTS x 1 x H x W, the mask quality scores of shape B x OBJECTS x 1, and the image embeddings of shape B x 256 x 64 x 64.
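A minimal end-to-end sketch of a forward call with a single box prompt. The image and box values are made up for illustration, and with randomly initialized weights the mask contents are meaningless; only the shapes are of interest here:

import torch

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

adaptor = PredictorAdaptor(model_type="vit_b")  # load pretrained weights for real use

# A single image as float tensor of shape B x C x H x W with B = 1.
image = torch.rand(1, 3, 512, 512) * 255

# One box prompt in (x0, y0, x1, y1) pixel coordinates, shape B x OBJECTS x 4.
box_prompts = torch.tensor([[[100.0, 120.0, 220.0, 260.0]]])

masks, scores, embeddings = adaptor(image=image, box_prompts=box_prompts)
print(masks.shape)       # torch.Size([1, 1, 1, 512, 512]), dtype torch.uint8
print(scores.shape)      # torch.Size([1, 1, 1])
print(embeddings.shape)  # torch.Size([1, 256, 64, 64])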