micro_sam.bioimageio.predictor_adaptor

import warnings
from typing import Optional, Tuple

import torch
from torch import nn

from segment_anything.predictor import SamPredictor

try:
    # Avoid import warnings from mobile_sam.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry
except ImportError:
    from segment_anything import sam_model_registry


class PredictorAdaptor(nn.Module):
    """Wrapper around the SamPredictor.

    This model supports the same functionality as SamPredictor and can provide mask segmentations
    from box, point or mask input prompts.

    Args:
        model_type: The type of the model for the image encoder.
            Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'.
            For 'vit_t' support, the 'mobile_sam' package has to be installed.
    """
    def __init__(self, model_type: str) -> None:
        super().__init__()
        sam_model = sam_model_registry[model_type]()
        self.sam = SamPredictor(sam_model)

    def load_state_dict(self, state):
        # Delegate to the wrapped SAM model rather than to nn.Module.
        self.sam.model.load_state_dict(state)

    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor,
        box_prompts: Optional[torch.Tensor] = None,
        point_prompts: Optional[torch.Tensor] = None,
        point_labels: Optional[torch.Tensor] = None,
        mask_prompts: Optional[torch.Tensor] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the SAM predictor to an image with the given prompts.

        Args:
            image: Torch inputs of dimensions B x C x H x W.
            box_prompts: Box coordinates of dimensions B x OBJECTS x 4.
            point_prompts: Point coordinates of dimensions B x OBJECTS x POINTS x 2.
            point_labels: Point labels of dimensions B x OBJECTS x POINTS.
            mask_prompts: Mask prompts of dimensions B x OBJECTS x 256 x 256.
            embeddings: Precomputed image embeddings of dimensions B x 256 x 64 x 64.

        Returns:
            The segmentation masks.
            The scores for prediction quality.
            The computed image embeddings.
        """
        batch_size = image.shape[0]
        if batch_size != 1:
            raise ValueError(f"Invalid batch size {batch_size}, only batch size 1 is supported.")

        # The image embeddings are already set and no embeddings were passed: nothing to do.
        if self.sam.is_image_set and embeddings is None:
            pass

        # Embeddings were passed, so we set them.
        elif embeddings is not None:
            self.sam.features = embeddings
            self.sam.orig_h, self.sam.orig_w = image.shape[2:]
            self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:]
            self.sam.is_image_set = True

        # No embeddings are set and none were passed: compute them from the image.
        elif not self.sam.is_image_set:
            input_ = self.sam.transform.apply_image_torch(image)
            self.sam.set_torch_image(input_, original_image_size=image.shape[2:])
            self.sam.orig_h, self.sam.orig_w = self.sam.original_size
            self.sam.input_h, self.sam.input_w = self.sam.input_size

        assert self.sam.is_image_set, "The predictor has not yet been initialized."

        # Ensure input size and original size are set.
        self.sam.input_size = (self.sam.input_h, self.sam.input_w)
        self.sam.original_size = (self.sam.orig_h, self.sam.orig_w)

        if box_prompts is None:
            boxes = None
        else:
            boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size)

        if point_prompts is None:
            point_coords = None
        else:
            assert point_labels is not None
            # Strip the batch axis from the point prompts and labels.
            point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0]
            point_labels = point_labels[0]

        if mask_prompts is None:
            mask_input = None
        else:
            # Strip the batch axis from the mask prompts.
            mask_input = mask_prompts[0]

        masks, scores, _ = self.sam.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            multimask_output=False,
        )

        assert masks.shape[2:] == image.shape[2:], \
            f"{masks.shape[2:]} is not as expected ({image.shape[2:]})"

        # Ensure a batch axis on the outputs.
        if masks.ndim == 4:
            masks = masks[None]
            assert scores.ndim == 2
            scores = scores[None]

        embeddings = self.sam.get_image_embedding()
        return masks.to(dtype=torch.uint8), scores, embeddings
class PredictorAdaptor(torch.nn.modules.module.Module):

Wrapper around the SamPredictor.

This model supports the same functionality as SamPredictor and can provide mask segmentations from box, point or mask input prompts.

Arguments:
  • model_type: The type of the model for the image encoder. Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. For 'vit_t' support, the 'mobile_sam' package has to be installed.
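
A minimal usage sketch (the checkpoint path and all tensor values below are placeholders; shapes follow the forward docstring further down):

import torch

from micro_sam.bioimageio.predictor_adaptor import PredictorAdaptor

adaptor = PredictorAdaptor(model_type="vit_b")
# Placeholder path; the official SAM checkpoints are plain state dicts,
# so the result of torch.load can be passed to load_state_dict directly.
adaptor.load_state_dict(torch.load("sam_vit_b.pth", map_location="cpu"))

image = torch.rand(1, 3, 512, 512) * 255              # B x C x H x W, values in [0, 255]
boxes = torch.tensor([[[64.0, 64.0, 192.0, 192.0]]])  # B x OBJECTS x 4 box coordinates
masks, scores, embeddings = adaptor(image, box_prompts=boxes)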
PredictorAdaptor(model_type: str)

Instantiates the SAM model for the given model_type and wraps it in a SamPredictor.

sam: The wrapped SamPredictor.
def load_state_dict(self, state):

Loads the given state dict into the wrapped SAM model, delegating to self.sam.model.load_state_dict.

Unlike torch.nn.Module.load_state_dict, this override does not accept the strict or assign keyword arguments.

Arguments:
  • state (dict): a dict containing the parameters and persistent buffers of the SAM model.
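
A small sketch of the intended call pattern (the checkpoint path is a placeholder):

import torch

adaptor = PredictorAdaptor(model_type="vit_b")
state = torch.load("checkpoint.pth", map_location="cpu")  # placeholder path
adaptor.load_state_dict(state)
# Equivalent to loading into the wrapped model directly:
# adaptor.sam.model.load_state_dict(state)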

@torch.no_grad()
def forward(
    self,
    image: torch.Tensor,
    box_prompts: Optional[torch.Tensor] = None,
    point_prompts: Optional[torch.Tensor] = None,
    point_labels: Optional[torch.Tensor] = None,
    mask_prompts: Optional[torch.Tensor] = None,
    embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Arguments:
  • image: Torch inputs of dimensions B x C x H x W.
  • box_prompts: Box coordinates of dimensions B x OBJECTS x 4.
  • point_prompts: Point coordinates of dimensions B x OBJECTS x POINTS x 2.
  • point_labels: Point labels of dimensions B x OBJECTS x POINTS.
  • mask_prompts: Mask prompts of dimensions B x OBJECTS x 256 x 256.
  • embeddings: Precomputed image embeddings of dimensions B x 256 x 64 x 64.
Returns:
  • The segmentation masks.
  • The scores for prediction quality.
  • The computed image embeddings.
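
A sketch of a point-prompt call, continuing the construction example above (adaptor as before; all tensor values are dummies). The returned embeddings can be passed back in so that further prompts for the same image skip the image encoder:

image = torch.rand(1, 3, 512, 512) * 255         # B x C x H x W
points = torch.tensor([[[[256.0, 256.0]]]])      # B x OBJECTS x POINTS x 2
labels = torch.ones(1, 1, 1, dtype=torch.int64)  # B x OBJECTS x POINTS, 1 = positive point

masks, scores, embeddings = adaptor(image, point_prompts=points, point_labels=labels)

# Reuse the precomputed embeddings so the image encoder is not run again.
masks2, scores2, _ = adaptor(
    image, point_prompts=points, point_labels=labels, embeddings=embeddings
)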
