micro_sam.bioimageio.bioengine_export

import os
import warnings
from typing import Optional, Union

import torch

from segment_anything.utils.onnx import SamOnnxModel

try:
    import onnxruntime
    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

from ..util import get_sam_model


ENCODER_CONFIG = """name: "%s"
backend: "pytorch"
platform: "pytorch_libtorch"

max_batch_size : 1
input [
  {
    name: "input0__0"
    data_type: TYPE_FP32
    dims: [3, -1, -1]
  }
]
output [
  {
    name: "output0__0"
    data_type: TYPE_FP32
    dims: [256, 64, 64]
  }
]

parameters: {
  key: "INFERENCE_MODE"
  value: {
    string_value: "true"
  }
}"""


DECODER_CONFIG = """name: "%s"
backend: "onnxruntime"
platform: "onnxruntime_onnx"

parameters: {
  key: "INFERENCE_MODE"
  value: {
    string_value: "true"
  }
}

instance_group {
  count: 1
  kind: KIND_CPU
}"""


def _to_numpy(tensor):
    return tensor.cpu().numpy()


def export_image_encoder(
    model_type: str,
    output_root: Union[str, os.PathLike],
    export_name: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Export SAM image encoder to torchscript.

    The torchscript image encoder can be used for predicting image embeddings
    with a backend, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner).

    Args:
        model_type: The SAM model type.
        output_root: The output root directory where the exported model is saved.
        export_name: The name of the exported model.
        checkpoint_path: Optional checkpoint for loading the SAM model.
    """
    if export_name is None:
        export_name = model_type
    name = f"sam-{export_name}-encoder"

    output_folder = os.path.join(output_root, name)
    weight_output_folder = os.path.join(output_folder, "1")
    os.makedirs(weight_output_folder, exist_ok=True)

    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, device="cpu")
    encoder = predictor.model.image_encoder

    encoder.eval()
    input_ = torch.rand(1, 3, 1024, 1024)
    traced_model = torch.jit.trace(encoder, input_)
    weight_path = os.path.join(weight_output_folder, "model.pt")
    traced_model.save(weight_path)

    config_output_path = os.path.join(output_folder, "config.pbtxt")
    with open(config_output_path, "w") as f:
        f.write(ENCODER_CONFIG % name)


def export_onnx_model(
    model_type: str,
    output_root: Union[str, os.PathLike],
    opset: int = 17,
    export_name: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    return_single_mask: bool = True,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
    quantize_model: bool = False,
) -> None:
    """Export SAM prompt encoder and mask decoder to onnx.

    The onnx encoder and decoder can be used for interactive segmentation in the browser.
    This code is adapted from
    https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py

    Args:
        model_type: The SAM model type.
        output_root: The output root directory where the exported model is saved.
        opset: The ONNX opset version. The recommended opset version is 17.
        export_name: The name of the exported model.
        checkpoint_path: Optional checkpoint for loading the SAM model.
        return_single_mask: Whether the mask decoder returns a single or multiple masks.
        gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
            does not have an efficient GeLU implementation.
        use_stability_score: Whether to use the stability score instead of the predicted score.
        return_extra_metrics: Whether to return a larger set of metrics.
        quantize_model: Whether to also export a quantized version of the model.
            This only works for onnxruntime < 1.17.
    """
    if export_name is None:
        export_name = model_type
    name = f"sam-{export_name}-decoder"

    output_folder = os.path.join(output_root, name)
    weight_output_folder = os.path.join(output_folder, "1")
    os.makedirs(weight_output_folder, exist_ok=True)

    _, sam = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, device="cpu")

    weight_path = os.path.join(weight_output_folder, "model.onnx")

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for _, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    dynamic_axes = {"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}}

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size

    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    _ = onnx_model(**dummy_inputs)

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(weight_path, "wb") as f:
            print(f"Exporting onnx model to {weight_path}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: _to_numpy(v) for k, v in dummy_inputs.items()}
        # Set the CPU provider as the default.
        providers = ["CPUExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(weight_path, providers=providers)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")

    # This requires onnxruntime < 1.17.
    # See https://github.com/facebookresearch/segment-anything/issues/699#issuecomment-1984670808
    if quantize_model:
        assert onnxruntime_exists
        from onnxruntime.quantization import QuantType
        from onnxruntime.quantization.quantize import quantize_dynamic

        quantized_path = os.path.join(weight_output_folder, "model_quantized.onnx")
        quantize_dynamic(
            model_input=weight_path,
            model_output=quantized_path,
            # optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )

    config_output_path = os.path.join(output_folder, "config.pbtxt")
    with open(config_output_path, "w") as f:
        f.write(DECODER_CONFIG % name)


def export_bioengine_model(
    model_type: str,
    output_root: Union[str, os.PathLike],
    opset: int,
    export_name: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    return_single_mask: bool = True,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Export SAM model to a format compatible with the BioEngine.

    [The bioengine](https://github.com/bioimage-io/bioengine-model-runner) enables running the
    image encoder on an online backend, so that SAM can be used in an online tool, or to predict
    the image embeddings via the online backend rather than on CPU.

    Args:
        model_type: The SAM model type.
        output_root: The output root directory where the exported model is saved.
        opset: The ONNX opset version.
        export_name: The name of the exported model.
        checkpoint_path: Optional checkpoint for loading the SAM model.
        return_single_mask: Whether the mask decoder returns a single or multiple masks.
        gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
            does not have an efficient GeLU implementation.
        use_stability_score: Whether to use the stability score instead of the predicted score.
        return_extra_metrics: Whether to return a larger set of metrics.
    """
    export_image_encoder(model_type, output_root, export_name, checkpoint_path)
    export_onnx_model(
        model_type=model_type,
        output_root=output_root,
        opset=opset,
        export_name=export_name,
        checkpoint_path=checkpoint_path,
        return_single_mask=return_single_mask,
        gelu_approximate=gelu_approximate,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )
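
The two config templates above follow the config.pbtxt format of the Triton Inference Server, which the BioEngine uses to serve models, and the export functions write their outputs in the matching model-repository layout. As a sketch, a call like export_bioengine_model("vit_b", "./models", opset=17) (example values) would produce:

    models/
    ├── sam-vit_b-encoder/
    │   ├── config.pbtxt      # rendered from ENCODER_CONFIG
    │   └── 1/
    │       └── model.pt      # torchscript image encoder
    └── sam-vit_b-decoder/
        ├── config.pbtxt      # rendered from DECODER_CONFIG
        └── 1/
            └── model.onnx    # onnx prompt encoder and mask decoder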
def export_image_encoder(model_type: str, output_root: Union[str, os.PathLike], export_name: Optional[str] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None) -> None:

Export SAM image encoder to torchscript.

The torchscript image encoder can be used for predicting image embeddings with a backend, e.g. with the bioengine.

Arguments:
  • model_type: The SAM model type.
  • output_root: The output root directory where the exported model is saved.
  • export_name: The name of the exported model.
  • checkpoint_path: Optional checkpoint for loading the SAM model.
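
A minimal usage sketch; the model type, output directory, and the check with torch.jit.load are illustrative, not part of this API:

    import torch
    from micro_sam.bioimageio.bioengine_export import export_image_encoder

    # Writes ./exported_models/sam-vit_b-encoder/1/model.pt plus config.pbtxt.
    export_image_encoder(model_type="vit_b", output_root="./exported_models")

    # Sanity-check: load the traced encoder and embed a dummy image.
    encoder = torch.jit.load("./exported_models/sam-vit_b-encoder/1/model.pt")
    embeddings = encoder(torch.rand(1, 3, 1024, 1024))  # shape (1, 256, 64, 64)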
def export_onnx_model(model_type: str, output_root: Union[str, os.PathLike], opset: int = 17, export_name: Optional[str] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_single_mask: bool = True, gelu_approximate: bool = False, use_stability_score: bool = False, return_extra_metrics: bool = False, quantize_model: bool = False) -> None:

Export SAM prompt encoder and mask decoder to onnx.

The onnx encoder and decoder can be used for interactive segmentation in the browser. This code is adapted from https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py

Arguments:
  • model_type: The SAM model type.
  • output_root: The output root directory where the exported model is saved.
  • opset: The ONNX opset version. The recommended opset version is 17.
  • export_name: The name of the exported model.
  • checkpoint_path: Optional checkpoint for loading the SAM model.
  • return_single_mask: Whether the mask decoder returns a single or multiple masks.
  • gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend does not have an efficient GeLU implementation.
  • use_stability_score: Whether to use the stability score instead of the predicted score.
  • return_extra_metrics: Whether to return a larger set of metrics.
  • quantize_model: Whether to also export a quantized version of the model. This only works for onnxruntime < 1.17.
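
A minimal usage sketch; the paths mirror what the function writes, and the onnxruntime check is illustrative:

    import onnxruntime
    from micro_sam.bioimageio.bioengine_export import export_onnx_model

    # Writes ./exported_models/sam-vit_b-decoder/1/model.onnx plus config.pbtxt.
    export_onnx_model(model_type="vit_b", output_root="./exported_models", opset=17)

    # The exported decoder loads directly with onnxruntime; its inputs match
    # the dummy inputs used for the export above.
    session = onnxruntime.InferenceSession(
        "./exported_models/sam-vit_b-decoder/1/model.onnx",
        providers=["CPUExecutionProvider"],
    )
    print([inp.name for inp in session.get_inputs()])
    # ['image_embeddings', 'point_coords', 'point_labels',
    #  'mask_input', 'has_mask_input', 'orig_im_size']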
def export_bioengine_model(model_type: str, output_root: Union[str, os.PathLike], opset: int, export_name: Optional[str] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_single_mask: bool = True, gelu_approximate: bool = False, use_stability_score: bool = False, return_extra_metrics: bool = False) -> None:

Export SAM model to a format compatible with the BioEngine.

The bioengine enables running the image encoder on an online backend, so that SAM can be used in an online tool, or to predict the image embeddings via the online backend rather than on CPU.

Arguments:
  • model_type: The SAM model type.
  • output_root: The output root directory where the exported model is saved.
  • opset: The ONNX opset version.
  • export_name: The name of the exported model.
  • checkpoint_path: Optional checkpoint for loading the SAM model.
  • return_single_mask: Whether the mask decoder returns a single or multiple masks.
  • gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend does not have an efficient GeLU implementation.
  • use_stability_score: Whether to use the stability score instead of the predicted score.
  • return_extra_metrics: Whether to return a larger set of metrics.
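
A minimal usage sketch; model type, paths, and the custom checkpoint are example values:

    from micro_sam.bioimageio.bioengine_export import export_bioengine_model

    # Exports both parts in one call: the torchscript encoder under
    # sam-vit_b-encoder/ and the onnx decoder under sam-vit_b-decoder/.
    export_bioengine_model(model_type="vit_b", output_root="./bioengine_models", opset=17)

    # With a fine-tuned checkpoint (hypothetical path); export_name changes
    # the folder names to sam-vit_b_finetuned-encoder / -decoder:
    # export_bioengine_model(
    #     model_type="vit_b", output_root="./bioengine_models", opset=17,
    #     checkpoint_path="./finetuned_vit_b.pt", export_name="vit_b_finetuned",
    # )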