micro_sam.bioimageio.predictor_adaptor

import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

try:
    # Avoid import warnings from mobile_sam.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry
except ImportError:
    from segment_anything import sam_model_registry


class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support, the 'mobile_sam' package must be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(sam_model)

    def load_state_dict(self, state):
        self.sam.model.load_state_dict(state)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the SAM model to the input image, using the given prompts.

        Args:
            image: torch inputs of dimensions B x C x H x W
            box_prompts: box coordinates of dimensions B x OBJECTS x 4
            point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
            point_labels: point labels of dimension B x OBJECTS x POINTS
            mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
            embeddings: precomputed image embeddings of dimension B x 256 x 64 x 64

        Returns:
            The binary segmentation masks (as uint8).
            The scores for the mask predictions.
            The image embeddings.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"Invalid batch size {batch_size}, only a batch size of 1 is supported.")

        # The image embeddings are already set and no new embeddings were passed: nothing to do.
        if self.sam.is_image_set and embeddings is None:
            pass

        # New embeddings were passed, so we set them.
        elif embeddings is not None:
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:]
            self.sam.is_image_set = True

        # No embeddings are set and none were passed: compute them from the image.
        elif not self.sam.is_image_set:
            input_ = self.sam.transform.apply_image_torch(image)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure that the input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        if box_prompts is None:
            boxes = None
        else:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        if point_prompts is None:
            point_coords = None
        else:
            assert point_labels is not None
            point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0]
            point_labels = point_labels[0]

        if mask_prompts is None:
            mask_input = None
        else:
            mask_input = mask_prompts[0]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"The mask shape {masks.shape[2:]} does not match the expected shape {image.shape[2:]}."

        # Ensure that the batch axis is present.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
class PredictorAdaptor(torch.nn.modules.module.Module):

Wrapper around the SamPredictor.

This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts.

Arguments:
  • model_type: The type of the model for the image encoder. Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. For 'vit_t' support, the 'mobile_sam' package must be installed.
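
For example, a minimal usage sketch with a box prompt (the image size and box coordinates are illustrative; note that the model registry initializes the network with random weights unless a state dict is loaded, see load_state_dict below):

import torch

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

adaptor = PredictorAdaptor(model_type="vit_b")

# A single image (the batch size must be 1) with raw pixel values in [0, 255]
# and one box prompt in pixel coordinates (MIN_X, MIN_Y, MAX_X, MAX_Y).
image = torch.rand(1, 3, 512, 512) * 255
boxes = torch.tensor([[[100.0, 100.0, 200.0, 200.0]]])  # B x OBJECTS x 4

masks, scores, embeddings = adaptor(image, box_prompts=boxes)
print(masks.shape)       # (1, 1, 1, 512, 512): B x OBJECTS x 1 x H x W, dtype uint8
print(scores.shape)      # (1, 1, 1)
print(embeddings.shape)  # (1, 256, 64, 64)
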
PredictorAdaptor(model_type: str)

Instantiate the SAM model for the given model type and wrap it in a SamPredictor.

sam: The wrapped SamPredictor instance.
def load_state_dict(self, state):

Load the given state dict into the wrapped SAM model.

The state is passed directly to self.sam.model.load_state_dict, so its keys must match the state dict of the underlying SAM model. The strict and assign arguments of torch.nn.Module.load_state_dict are not exposed by this override.
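
A minimal loading sketch, assuming a plain SAM checkpoint saved as a state dict (the file name is hypothetical; checkpoints saved in other formats may nest the weights under an additional key):

import torch

adaptor = PredictorAdaptor(model_type="vit_b")
state = torch.load("sam_vit_b.pth", map_location="cpu")  # hypothetical checkpoint path
adaptor.load_state_dict(state)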

@torch.no_grad()
def forward( self, image: torch.Tensor, box_prompts: Optional[torch.Tensor] = None, point_prompts: Optional[torch.Tensor] = None, point_labels: Optional[torch.Tensor] = None, mask_prompts: Optional[torch.Tensor] = None, embeddings: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Apply the SAM model to the input image, using the given prompts.

Arguments:
  • image: torch inputs of dimensions B x C x H x W
  • box_prompts: box coordinates of dimensions B x OBJECTS x 4
  • point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2
  • point_labels: point labels of dimension B x OBJECTS x POINTS
  • mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256
  • embeddings: precomputed image embeddings B x 256 x 64 x 64

Returns:

The binary segmentation masks (as uint8), the scores for the mask predictions, and the image embeddings.

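Since forward returns the image embeddings, they can be passed back in via the embeddings argument to avoid recomputing the image encoder when prompting the same image repeatedly. A sketch, assuming adaptor, image and boxes as in the class example above:

# The first call computes the image embeddings.
masks, scores, embeddings = adaptor(image, box_prompts=boxes)

# Further calls with new prompts can reuse the precomputed embeddings.
points = torch.tensor([[[[150.0, 150.0]]]])    # B x OBJECTS x POINTS x 2
labels = torch.ones(1, 1, 1, dtype=torch.int)  # B x OBJECTS x POINTS; 1 = positive point
masks2, scores2, _ = adaptor(
    image, point_prompts=points, point_labels=labels, embeddings=embeddings
)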