micro_sam.bioimageio.predictor_adaptor
```python
import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

try:
    # Avoid import warnings from mobile_sam
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry
except ImportError:
    from segment_anything import sam_model_registry


class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support, the 'mobile_sam' package must be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(sam_model)

    def load_state_dict(self, state):
        self.sam.model.load_state_dict(state)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            image: torch inputs of shape B x C x H x W
            box_prompts: box coordinates of shape B x OBJECTS x 4
            point_prompts: point coordinates of shape B x OBJECTS x POINTS x 2
            point_labels: point labels of shape B x OBJECTS x POINTS
            mask_prompts: mask prompts of shape B x OBJECTS x 256 x 256
            embeddings: precomputed image embeddings of shape B x 256 x 64 x 64

        Returns:
            The segmentation masks.
            The scores for prediction quality.
            The computed image embeddings.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"Expected a batch size of 1, got {batch_size}.")

        # We have image embeddings set and image embeddings were not passed.
        if self.sam.is_image_set and embeddings is None:
            pass  # do nothing

        # The embeddings are passed, so we set them.
        elif embeddings is not None:
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:]
            self.sam.is_image_set = True

        # We don't have image embeddings set and they were not passed.
        elif not self.sam.is_image_set:
            input_ = self.sam.transform.apply_image_torch(image)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        if box_prompts is None:
            boxes = None
        else:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        if point_prompts is None:
            point_coords = None
        else:
            assert point_labels is not None
            point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0]
            point_labels = point_labels[0]

        if mask_prompts is None:
            mask_input = None
        else:
            mask_input = mask_prompts[0]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure batch axis.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
```
```python
class PredictorAdaptor(nn.Module):
```
Wrapper around the SamPredictor.
This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts.
Arguments:
- model_type: The type of the model for the image encoder. Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. For 'vit_t' support, the 'mobile_sam' package must be installed.
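For illustration, a minimal construction sketch; the choice of 'vit_b' is just an example, and 'vit_t' would additionally require mobile_sam to be importable:

```python
from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

# Builds the SAM architecture via sam_model_registry and wraps it in a
# SamPredictor. No pretrained weights are loaded at this point; use
# load_state_dict (below) to load a checkpoint.
adaptor = PredictorAdaptor(model_type="vit_b")
```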
```python
def __init__(self, model_type: str) -> None:
    super().__init__()
    sam_model = sam_model_registry[model_type]()
    self.sam = SamPredictor(sam_model)
```
Instantiates the SAM model for the given model_type via sam_model_registry and wraps it in a SamPredictor.
Note: PredictorAdaptor overrides load_state_dict with the simpler signature load_state_dict(state); it forwards the state dict to the wrapped SAM model (self.sam.model) with the default strict=True and does not return the NamedTuple described below. The remaining documentation is inherited from torch.nn.Module.load_state_dict.

Copy parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's torch.nn.Module.state_dict() function.

If assign is True, the optimizer must be created after the call to load_state_dict unless torch.__future__.get_swap_module_params_on_conversion() is True.

Arguments:
- state_dict (dict): a dict containing parameters and persistent buffers.
- strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict() function. Default: True
- assign (bool, optional): When False, the properties of the tensors in the current module are preserved; when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of torch.nn.Parameters, for which the value from the module is preserved. Default: False

Returns:
NamedTuple with missing_keys and unexpected_keys fields:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Note:
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
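A minimal loading sketch; the checkpoint filename is a placeholder, and it assumes the file contains a plain SAM state dict (as the official SAM checkpoints do):

```python
import torch

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

adaptor = PredictorAdaptor(model_type="vit_b")

# Placeholder path: use any SAM checkpoint that matches the chosen model_type.
state = torch.load("sam_vit_b.pth", map_location="cpu")

# Forwards to the wrapped SAM model, i.e. this is equivalent to
# adaptor.sam.model.load_state_dict(state).
adaptor.load_state_dict(state)
```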
```python
@torch.no_grad()
def forward(
    self,
    image: torch.Tensor,
    box_prompts: Optional[torch.Tensor] = None,
    point_prompts: Optional[torch.Tensor] = None,
    point_labels: Optional[torch.Tensor] = None,
    mask_prompts: Optional[torch.Tensor] = None,
    embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
```
Arguments:
- image: input image of shape B x C x H x W; only batch size B = 1 is supported.
- box_prompts: box coordinates of shape B x OBJECTS x 4
- point_prompts: point coordinates of shape B x OBJECTS x POINTS x 2
- point_labels: point labels of shape B x OBJECTS x POINTS
- mask_prompts: mask prompts of shape B x OBJECTS x 256 x 256
- embeddings: precomputed image embeddings of shape B x 256 x 64 x 64

Returns:
- The segmentation masks (as uint8).
- The scores for prediction quality.
- The computed image embeddings.
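An end-to-end sketch of calling forward with a single box prompt. The image is random and the box coordinates are made up, so the predicted masks are meaningless; the point is to illustrate the shape contract, and that passing the returned embeddings back in skips re-running the image encoder:

```python
import torch

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

adaptor = PredictorAdaptor(model_type="vit_b")

image = torch.rand(1, 3, 512, 512)                      # B x C x H x W, B must be 1
boxes = torch.tensor([[[100.0, 120.0, 300.0, 320.0]]])  # B x OBJECTS x 4

masks, scores, embeddings = adaptor(image=image, box_prompts=boxes)
print(masks.shape)       # torch.Size([1, 1, 1, 512, 512])
print(scores.shape)      # torch.Size([1, 1, 1])
print(embeddings.shape)  # torch.Size([1, 256, 64, 64])

# Reuse the precomputed embeddings for further prompts on the same image:
masks2, _, _ = adaptor(image=image, box_prompts=boxes, embeddings=embeddings)
```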