micro_sam.bioimageio.model_export

import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import xarray
import numpy as np
import matplotlib.pyplot as plt

import torch

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.digest_spec import create_sample_for_model

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
from ..prompt_based_segmentation import _compute_logits_from_mask
from .predictor_adaptor import PredictorAdaptor


DEFAULTS = {
    "authors": [
        spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
        spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"),
    ],
    "description": "Finetuned Segment Anything Model for Microscopy",
    "cite": [
        spec.CiteEntry(
            text="Archit et al. Segment Anything for Microscopy",
            doi=spec.Doi("10.1038/s41592-024-02580-4")
        ),
    ],
    "tags": ["segment-anything", "instance-segmentation"],
}

# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269  # noqa
# This parameter used to be part of bioimageio.spec but has been removed, so we keep a copy of it here.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

    # For now we just generate a single box prompt here, but we could also generate more input prompts.
    generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=2,
        dilation_strength=2,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
    masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2])  # type: ignore
    point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])

    box_prompts = box_prompts.numpy()[None]
    point_prompts = point_prompts.numpy()[None]
    point_labels = point_labels.numpy()[None]

    # Generate mask prompts (logits) from the masks of the two objects.
    mask_prompts = np.stack(
        [_compute_logits_from_mask(labels == 1), _compute_logits_from_mask(labels == 2)]
    )[None]

    predictor = PredictorAdaptor(model_type=model_type)
    predictor.load_state_dict(torch.load(checkpoint_path))

    input_ = util._to_image(image).transpose(2, 0, 1)[None]
    image_path = os.path.join(tmp_dir, "input.npy")
    np.save(image_path, input_)

    masks, scores, embeddings = predictor(
        image=torch.from_numpy(input_),
        embeddings=None,
        box_prompts=torch.from_numpy(box_prompts),
        point_prompts=torch.from_numpy(point_prompts),
        point_labels=torch.from_numpy(point_labels),
        mask_prompts=torch.from_numpy(mask_prompts),
    )

    box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy")
    point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy")
    point_label_path = os.path.join(tmp_dir, "point_labels.npy")
    mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy")
    np.save(box_prompt_path, box_prompts.astype("int64"))
    np.save(point_prompt_path, point_prompts)
    np.save(point_label_path, point_labels)
    np.save(mask_prompt_path, mask_prompts)

    mask_path = os.path.join(tmp_dir, "mask.npy")
    score_path = os.path.join(tmp_dir, "scores.npy")
    embed_path = os.path.join(tmp_dir, "embeddings.npy")
    np.save(mask_path, masks.numpy())
    np.save(score_path, scores.numpy())
    np.save(embed_path, embeddings.numpy())

    inputs = {
        "image": image_path,
        "box_prompts": box_prompt_path,
        "point_prompts": point_prompt_path,
        "point_labels": point_label_path,
        "mask_prompts": mask_prompt_path,
    }
    outputs = {"mask": mask_path, "score": score_path, "embeddings": embed_path}
    return inputs, outputs


def _write_documentation(doc, model_type, tmp_dir):
    tmp_doc_path = os.path.join(tmp_dir, "documentation.md")

    if doc is None:
        with open(tmp_doc_path, "w") as f:
            f.write("# Segment Anything for Microscopy\n")
            f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
            f.write("by training specialized models for microscopy data.\n")
        return tmp_doc_path

    elif os.path.exists(doc):
        return doc

    else:
        with open(tmp_doc_path, "w") as f:
            f.write(doc)
        return tmp_doc_path


def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
    # If we don't have a checkpoint we get the corresponding model from the registry.
    if checkpoint_path is None:
        model_registry = util.models()
        checkpoint_path = model_registry.fetch(model_type)
        return checkpoint_path, None

    # Otherwise we have to load the checkpoint to see if it is the state dict of an encoder,
    # or the checkpoint for a custom SAM model.
    state, model_state = util._load_checkpoint(checkpoint_path)

    if "model_state" in state:  # This is a finetuning checkpoint -> we have to resave the state.
        new_checkpoint_path = os.path.join(tmp_dir, f"{model_type}.pt")
        torch.save(model_state, new_checkpoint_path)

        # We may also have an instance segmentation decoder in that case.
        # If we have it we also resave this one and return it.
        if "decoder_state" in state:
            decoder_path = os.path.join(tmp_dir, f"{model_type}_decoder.pt")
            decoder_state = state["decoder_state"]
            torch.save(decoder_state, decoder_path)
        else:
            decoder_path = None

        return new_checkpoint_path, decoder_path

    else:  # This is a SAM encoder state -> we don't have to resave.
        return checkpoint_path, None


# TODO: Update this to match our latest environment yaml files.
def _write_dependencies(dependency_file, require_mobile_sam):
    content = """name: sam
channels:
    - pytorch
    - conda-forge
dependencies:
    - segment-anything"""
    if require_mobile_sam:
        content += """
    - timm
    - pip:
        - git+https://github.com/ChaoningZhang/MobileSAM.git"""
    with open(dependency_file, "w") as f:
        f.write(content)


def _generate_covers(input_paths, result_paths, tmp_dir):
    image = np.load(input_paths["image"]).squeeze()
    prompts = np.load(input_paths["box_prompts"])
    mask = np.load(result_paths["mask"])

    # Create the image overlay.
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    elif image.shape[0] == 3:
        overlay = image.transpose((1, 2, 0))
    else:
        overlay = image
    overlay = _enhance_image(overlay.astype("float32"))

    # Overlay the mask as an outline.
    overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2)

    # Overlay the bounding box prompt.
    prompt = prompts[0, 0][[1, 0, 3, 2]]
    prompt = np.array([prompt[:2], prompt[2:]])
    overlay = _overlay_box(overlay, prompt, outline_dilation=4)

    # Write the cover image.
    fig, ax = plt.subplots(1)
    ax.axis("off")
    ax.imshow(overlay.astype("uint8"))
    cover_path = os.path.join(tmp_dir, "cover.jpeg")
    plt.savefig(cover_path, bbox_inches="tight")
    plt.close()

    covers = [cover_path]
    return covers


def _check_model(model_description, input_paths, result_paths):
    # Load inputs.
    image = xarray.DataArray(np.load(input_paths["image"]), dims=("batch", "channel", "y", "x"))
    embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=("batch", "channel", "y", "x"))
    box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=("batch", "object", "channel"))
    point_prompts = xarray.DataArray(
        np.load(input_paths["point_prompts"]), dims=("batch", "object", "point", "channel")
    )
    point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=("batch", "object", "point"))
    mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=("batch", "object", "channel", "y", "x"))

    # Load outputs.
    mask = np.load(result_paths["mask"])

    with bioimageio.core.create_prediction_pipeline(model_description) as pp:

        # Check with all prompts. We only check the result for this setting,
        # because this was used to generate the test data.
        sample = create_sample_for_model(
            model=model_description,
            inputs={
                "image": image,
                "box_prompts": box_prompts,
                "point_prompts": point_prompts,
                "point_labels": point_labels,
                "mask_prompts": mask_prompts,
                "embeddings": embeddings,
            },
        ).as_single_block()
        prediction = pp.predict_sample_block(sample)

        predicted_mask = prediction.blocks["masks"].data.data
        assert predicted_mask.shape == mask.shape
        assert np.allclose(mask, predicted_mask)

        # Run the checks with partial prompts.
        prompt_kwargs = [
            # With boxes.
            {"box_prompts": box_prompts},
            # With point prompts.
            {"point_prompts": point_prompts, "point_labels": point_labels},
            # With masks.
            {"mask_prompts": mask_prompts},
            # With boxes and points.
            {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
            # With boxes and masks.
            {"box_prompts": box_prompts, "mask_prompts": mask_prompts},
            # With points and masks.
            {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
        ]

        for kwargs in prompt_kwargs:
            sample = create_sample_for_model(
                model=model_description, inputs={"image": image, "embeddings": embeddings, **kwargs},
            ).as_single_block()
            prediction = pp.predict_sample_block(sample)
            predicted_mask = prediction.blocks["masks"].data.data
            assert predicted_mask.shape == mask.shape


def export_sam_model(
    image: np.ndarray,
    label_image: np.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    **kwargs
) -> None:
    """Export SAM model to BioImage.IO model format.

    The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
    be used in tools that support the BioImage.IO model format.

    Args:
        image: The image for generating test data.
        label_image: The segmentation corresponding to `image`.
            It is used to derive prompt inputs for the model.
        model_type: The type of the SAM model.
        name: The name of the exported model.
        output_path: Where the exported model is saved.
        checkpoint_path: Optional checkpoint for loading the SAM model.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path, decoder_path = _get_checkpoint(model_type, checkpoint_path, tmp_dir)
        input_paths, result_paths = _create_test_inputs_and_outputs(
            image, label_image, model_type, checkpoint_path, tmp_dir,
        )
        input_descriptions = [
            # First input: the image data.
            spec.InputTensorDescr(
                id=spec.TensorId("image"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: to support both 1 and 3 channels we could add another preprocessing step
                    # that converts single-channel inputs to RGB (1C -> RGB).
                    spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
                ],
                test_tensor=spec.FileDescr(source=input_paths["image"]),
                data=spec.IntervalOrRatioDataDescr(type="uint8")
            ),

            # Second input: the box prompts (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("box_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Third input: the point prompt coordinates (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("point_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fourth input: the point prompt labels (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("point_labels"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fifth input: the mask prompts (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("mask_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=["channel"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=256),
                ],
                test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

            # Sixth input: the image embeddings (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("embeddings"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we currently have to specify all the channel names
                    # (it would be nice to also support just specifying the size).
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
                ],
                test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

        ]

        output_descriptions = [
            # First output: the mask predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("masks"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data-dependent size here to avoid a dependency on optional inputs.
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be 3 once we use multi-masking.
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("y"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
                        )
                    ),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("x"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
                        )
                    )
                ],
                data=spec.IntervalOrRatioDataDescr(type="uint8"),
                test_tensor=spec.FileDescr(source=result_paths["mask"])
            ),

            # Second output: the score predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("scores"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data-dependent size here to avoid a dependency on optional inputs.
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be 3 once we use multi-masking.
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["score"])
            ),

            # Third output: the image embeddings.
            spec.OutputTensorDescr(
                id=spec.TensorId("embeddings"),
                axes=[
                    spec.BatchAxis(size=1),
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["embeddings"])
            )
        ]

        architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(architecture_path),
            callable="PredictorAdaptor",
            kwargs={"model_type": model_type}
        )

        dependency_file = os.path.join(tmp_dir, "environment.yaml")
        _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t"))

        weight_descriptions = spec.WeightsDescr(
            pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
                source=Path(checkpoint_path),
                architecture=architecture,
                pytorch_version=spec.Version(torch.__version__),
                dependencies=spec.EnvironmentFileDescr(source=dependency_file),
            )
        )

        doc_path = _write_documentation(kwargs.get("documentation", None), model_type, tmp_dir)

        covers = kwargs.get("covers", None)
        if covers is None:
            covers = _generate_covers(input_paths, result_paths, tmp_dir)
        else:
            assert all(os.path.exists(cov) for cov in covers)

        # The uploader information is only added if explicitly passed.
        extra_kwargs = {}
        if "id" in kwargs:
            extra_kwargs["id"] = kwargs["id"]
        if "id_emoji" in kwargs:
            extra_kwargs["id_emoji"] = kwargs["id_emoji"]
        if "uploader" in kwargs:
            extra_kwargs["uploader"] = kwargs["uploader"]
        if "version" in kwargs:
            extra_kwargs["version"] = kwargs["version"]

        if decoder_path is not None:
            extra_kwargs["attachments"] = [spec.FileDescr(source=decoder_path)]

        model_description = spec.ModelDescr(
            name=name,
            inputs=input_descriptions,
            outputs=output_descriptions,
            weights=weight_descriptions,
            description=kwargs.get("description", DEFAULTS["description"]),
            authors=kwargs.get("authors", DEFAULTS["authors"]),
            cite=kwargs.get("cite", DEFAULTS["cite"]),
            license=spec.LicenseId("CC-BY-4.0"),
            documentation=Path(doc_path),
            git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
            tags=kwargs.get("tags", DEFAULTS["tags"]),
            covers=covers,
            **extra_kwargs,
            # TODO write specific settings in the config
            # dict with yaml values, key must be a str
            # micro_sam: ...
            # config=
        )

        _check_model(model_description, input_paths, result_paths)

        save_bioimageio_package(model_description, output_path=output_path)
DEFAULTS = {
    'authors': [
        Author(affiliation='University Goettingen', email=None, orcid=None, name='Anwai Archit', github_user='anwai98'),
        Author(affiliation='University Goettingen', email=None, orcid=None, name='Constantin Pape', github_user='constantinpape'),
    ],
    'description': 'Finetuned Segment Anything Model for Microscopy',
    'cite': [
        CiteEntry(text='Archit et al. Segment Anything for Microscopy', doi='10.1038/s41592-024-02580-4', url=None),
    ],
    'tags': ['segment-anything', 'instance-segmentation'],
}
ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
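
For context, a ParameterizedSize(min=m, step=s) in the bioimageio.spec v0.5 axis descriptions denotes an axis that accepts any size of the form m + n*s with n >= 0, so ARBITRARY_SIZE above allows any positive integer. A minimal sketch of that rule (the helper valid_sizes is hypothetical, for illustration only):

    # Sizes accepted by ParameterizedSize(min=m, step=s): m + n*s for n >= 0.
    def valid_sizes(m=1, s=1, n_max=5):  # hypothetical helper
        return [m + n * s for n in range(n_max)]

    print(valid_sizes())  # [1, 2, 3, 4, 5] -> any positive integer is accepted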
def export_sam_model(
    image: numpy.ndarray,
    label_image: numpy.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    **kwargs
) -> None:

Export SAM model to BioImage.IO model format.

The exported model can be uploaded to bioimage.io and be used in tools that support the BioImage.IO model format.

Arguments:
  • image: The image for generating test data.
  • label_image: The segmentation corresponding to image. It is used to derive prompt inputs for the model.
  • model_type: The type of the SAM model.
  • name: The name of the exported model.
  • output_path: Where the exported model is saved.
  • checkpoint_path: Optional checkpoint for loading the SAM model.
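
A minimal usage sketch, assuming a 2D test image and a corresponding instance segmentation that contains objects with the ids 1 and 2 (the test prompt generation above hard-codes segmentation_ids=[1, 2]); the file names and the checkpoint path are hypothetical:

    import imageio.v3 as imageio
    from micro_sam.bioimageio.model_export import export_sam_model

    # Hypothetical test data: a microscopy image and its instance segmentation.
    # The segmentation must contain the object ids 1 and 2, which are used to
    # derive the test prompts (see _create_test_inputs_and_outputs above).
    image = imageio.imread("image.tif")
    label_image = imageio.imread("labels.tif")

    export_sam_model(
        image,
        label_image,
        model_type="vit_b",  # the SAM image encoder variant
        name="sam-finetuned-example",
        output_path="./sam-finetuned-example.zip",
        checkpoint_path="./checkpoints/best.pt",  # hypothetical finetuning checkpoint
    )

Note that export_sam_model already validates the exported model description via _check_model before saving, so the package written to output_path has been run against its own test tensors and can then be uploaded to bioimage.io.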