micro_sam.bioimageio.model_export

import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import xarray
import numpy as np
import matplotlib.pyplot as plt

import torch

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.digest_spec import create_sample_for_model

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
from ..prompt_based_segmentation import _compute_logits_from_mask
from .predictor_adaptor import PredictorAdaptor


DEFAULTS = {
    "authors": [
        spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
        spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"),
    ],
    "description": "Finetuned Segment Anything Model for Microscopy",
    "cite": [
        spec.CiteEntry(
            text="Archit et al. Segment Anything for Microscopy",
            doi=spec.Doi("10.1038/s41592-024-02580-4")
        ),
    ],
    "tags": ["segment-anything", "instance-segmentation"],
}

# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269  # noqa
# This parameter used to be available in bioimageio.spec but has since been removed, so we keep a copy of it here.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

    # For now we just generate point and box prompts for two objects here, but we could also generate more input prompts.
    generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=2,
        dilation_strength=2,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
    masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2])  # type: ignore
    point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])

    box_prompts = box_prompts.numpy()[None]
    point_prompts = point_prompts.numpy()[None]
    point_labels = point_labels.numpy()[None]

    # Generate mask prompts (logits) from the two objects.
    mask_prompts = np.stack(
        [_compute_logits_from_mask(labels == 1), _compute_logits_from_mask(labels == 2)]
    )[None]

    predictor = PredictorAdaptor(model_type=model_type)
    predictor.load_state_dict(torch.load(checkpoint_path))

    input_ = util._to_image(image).transpose(2, 0, 1)[None]
    image_path = os.path.join(tmp_dir, "input.npy")
    np.save(image_path, input_)

    masks, scores, embeddings = predictor(
        image=torch.from_numpy(input_),
        embeddings=None,
        box_prompts=torch.from_numpy(box_prompts),
        point_prompts=torch.from_numpy(point_prompts),
        point_labels=torch.from_numpy(point_labels),
        mask_prompts=torch.from_numpy(mask_prompts),
    )

    box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy")
    point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy")
    point_label_path = os.path.join(tmp_dir, "point_labels.npy")
    mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy")
    np.save(box_prompt_path, box_prompts.astype("int64"))
    np.save(point_prompt_path, point_prompts)
    np.save(point_label_path, point_labels)
    np.save(mask_prompt_path, mask_prompts)

    mask_path = os.path.join(tmp_dir, "mask.npy")
    score_path = os.path.join(tmp_dir, "scores.npy")
    embed_path = os.path.join(tmp_dir, "embeddings.npy")
    np.save(mask_path, masks.numpy())
    np.save(score_path, scores.numpy())
    np.save(embed_path, embeddings.numpy())

    inputs = {
        "image": image_path,
        "box_prompts": box_prompt_path,
        "point_prompts": point_prompt_path,
        "point_labels": point_label_path,
        "mask_prompts": mask_prompt_path,
    }
    outputs = {"mask": mask_path, "score": score_path, "embeddings": embed_path}
    return inputs, outputs


def _write_documentation(doc, model_type, tmp_dir):
    tmp_doc_path = os.path.join(tmp_dir, "documentation.md")

    if doc is None:
        with open(tmp_doc_path, "w") as f:
            f.write("# Segment Anything for Microscopy\n")
            f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
            f.write("by training specialized models for microscopy data.\n")
        return tmp_doc_path

    elif os.path.exists(doc):
        return doc

    else:
        with open(tmp_doc_path, "w") as f:
            f.write(doc)
        return tmp_doc_path


def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
    # If we don't have a checkpoint we get the corresponding model from the registry.
    if checkpoint_path is None:
        model_registry = util.models()
        checkpoint_path = model_registry.fetch(model_type)
        return checkpoint_path, None

    # Otherwise we have to load the checkpoint to see if it is the state dict of an encoder,
    # or the checkpoint for a custom SAM model.
    state, model_state = util._load_checkpoint(checkpoint_path)

    if "model_state" in state:  # This is a finetuning checkpoint -> we have to resave the state.
        new_checkpoint_path = os.path.join(tmp_dir, f"{model_type}.pt")
        torch.save(model_state, new_checkpoint_path)

        # We may also have an instance segmentation decoder in that case.
        # If we have it we also resave this one and return it.
        if "decoder_state" in state:
            decoder_path = os.path.join(tmp_dir, f"{model_type}_decoder.pt")
            decoder_state = state["decoder_state"]
            torch.save(decoder_state, decoder_path)
        else:
            decoder_path = None

        return new_checkpoint_path, decoder_path

    else:  # This is a SAM encoder state -> we don't have to resave.
        return checkpoint_path, None


# TODO: Update this with our latest yaml file updates.
def _write_dependencies(dependency_file, require_mobile_sam):
    content = """name: sam
channels:
    - pytorch
    - conda-forge
dependencies:
    - segment-anything"""
    if require_mobile_sam:
        content += """
    - timm
    - pip:
        - git+https://github.com/ChaoningZhang/MobileSAM.git"""
    with open(dependency_file, "w") as f:
        f.write(content)


def _generate_covers(input_paths, result_paths, tmp_dir):
    image = np.load(input_paths["image"]).squeeze()
    prompts = np.load(input_paths["box_prompts"])
    mask = np.load(result_paths["mask"])

    # create the image overlay
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    elif image.shape[0] == 3:
        overlay = image.transpose((1, 2, 0))
    else:
        overlay = image
    overlay = _enhance_image(overlay.astype("float32"))

    # overlay the mask as outline
    overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2)

    # overlay the bounding box prompt
    prompt = prompts[0, 0][[1, 0, 3, 2]]
    prompt = np.array([prompt[:2], prompt[2:]])
    overlay = _overlay_box(overlay, prompt, outline_dilation=4)

    # write the cover image
    fig, ax = plt.subplots(1)
    ax.axis("off")
    ax.imshow(overlay.astype("uint8"))
    cover_path = os.path.join(tmp_dir, "cover.jpeg")
    plt.savefig(cover_path, bbox_inches="tight")
    plt.close()

    covers = [cover_path]
    return covers


def _check_model(model_description, input_paths, result_paths):
    # Load inputs.
    image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
    embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
    box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
    point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("bhwc"))
    point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
    mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))

    # Load outputs.
    mask = np.load(result_paths["mask"])

    with bioimageio.core.create_prediction_pipeline(model_description) as pp:

        # Check with all prompts. We only check the result for this setting,
        # because this was used to generate the test data.
        sample = create_sample_for_model(
            model=model_description,
            image=image,
            box_prompts=box_prompts,
            point_prompts=point_prompts,
            point_labels=point_labels,
            mask_prompts=mask_prompts,
            embeddings=embeddings,
        ).as_single_block()
        prediction = pp.predict_sample_block(sample)

        predicted_mask = prediction.blocks["masks"].data.data
        assert predicted_mask.shape == mask.shape
        assert np.allclose(mask, predicted_mask)

        # Run the checks with partial prompts.
        prompt_kwargs = [
            # With boxes.
            {"box_prompts": box_prompts},
            # With point prompts.
            {"point_prompts": point_prompts, "point_labels": point_labels},
            # With masks.
            {"mask_prompts": mask_prompts},
            # With boxes and points.
            {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
            # With boxes and masks.
            {"box_prompts": box_prompts, "mask_prompts": mask_prompts},
            # With points and masks.
            {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
        ]

        for kwargs in prompt_kwargs:
            sample = create_sample_for_model(
                model=model_description, image=image, embeddings=embeddings, **kwargs
            ).as_single_block()
            prediction = pp.predict_sample_block(sample)
            predicted_mask = prediction.blocks["masks"].data.data
            assert predicted_mask.shape == mask.shape


def export_sam_model(
    image: np.ndarray,
    label_image: np.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    **kwargs
) -> None:
    """Export SAM model to BioImage.IO model format.

    The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
    be used in tools that support the BioImage.IO model format.

    Args:
        image: The image for generating test data.
        label_image: The segmentation corresponding to `image`.
            It is used to derive prompt inputs for the model.
        model_type: The type of the SAM model.
        name: The name of the exported model.
        output_path: Where the exported model is saved.
        checkpoint_path: Optional checkpoint for loading the SAM model.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path, decoder_path = _get_checkpoint(model_type, checkpoint_path, tmp_dir)
        input_paths, result_paths = _create_test_inputs_and_outputs(
            image, label_image, model_type, checkpoint_path, tmp_dir,
        )
        input_descriptions = [
            # First input: the image data.
            spec.InputTensorDescr(
                id=spec.TensorId("image"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: to support both 1 and 3 channels we could add another preprocessing step
                    # that converts single-channel inputs to RGB.
                    spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
                ],
                test_tensor=spec.FileDescr(source=input_paths["image"]),
                data=spec.IntervalOrRatioDataDescr(type="uint8")
            ),

            # Second input: the box prompts (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("box_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Third input: the point prompt coordinates (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("point_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fourth input: the point prompt labels (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("point_labels"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fifth input: the mask prompts (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("mask_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=["channel"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=256),
                ],
                test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

            # Sixth input: the image embeddings (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("embeddings"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we currently have to specify all the channel names
                    # (It would be nice to also support size)
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
                ],
                test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

        ]

        output_descriptions = [
            # First output: The mask predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("masks"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be a 3 once we use multi-masking
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("y"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
                        )
                    ),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("x"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
                        )
                    )
                ],
                data=spec.IntervalOrRatioDataDescr(type="uint8"),
                test_tensor=spec.FileDescr(source=result_paths["mask"])
            ),

            # The score predictions
            spec.OutputTensorDescr(
                id=spec.TensorId("scores"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be a 3 once we use multi-masking
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["score"])
            ),

            # The image embeddings
            spec.OutputTensorDescr(
                id=spec.TensorId("embeddings"),
                axes=[
                    spec.BatchAxis(size=1),
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["embeddings"])
            )
        ]

        architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(architecture_path),
            callable="PredictorAdaptor",
            kwargs={"model_type": model_type}
        )

        dependency_file = os.path.join(tmp_dir, "environment.yaml")
        _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t"))

        weight_descriptions = spec.WeightsDescr(
            pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
                source=Path(checkpoint_path),
                architecture=architecture,
                pytorch_version=spec.Version(torch.__version__),
                dependencies=spec.EnvironmentFileDescr(source=dependency_file),
            )
        )

        doc_path = _write_documentation(kwargs.get("documentation", None), model_type, tmp_dir)

        covers = kwargs.get("covers", None)
        if covers is None:
            covers = _generate_covers(input_paths, result_paths, tmp_dir)
        else:
            assert all(os.path.exists(cov) for cov in covers)

        # the uploader information is only added if explicitly passed
        extra_kwargs = {}
        if "id" in kwargs:
            extra_kwargs["id"] = kwargs["id"]
        if "id_emoji" in kwargs:
            extra_kwargs["id_emoji"] = kwargs["id_emoji"]
        if "uploader" in kwargs:
            extra_kwargs["uploader"] = kwargs["uploader"]
        if "version" in kwargs:
            extra_kwargs["version"] = kwargs["version"]

        if decoder_path is not None:
            extra_kwargs["attachments"] = [spec.FileDescr(source=decoder_path)]

        model_description = spec.ModelDescr(
            name=name,
            inputs=input_descriptions,
            outputs=output_descriptions,
            weights=weight_descriptions,
            description=kwargs.get("description", DEFAULTS["description"]),
            authors=kwargs.get("authors", DEFAULTS["authors"]),
            cite=kwargs.get("cite", DEFAULTS["cite"]),
            license=spec.LicenseId("CC-BY-4.0"),
            documentation=Path(doc_path),
            git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
            tags=kwargs.get("tags", DEFAULTS["tags"]),
            covers=covers,
            **extra_kwargs,
            # TODO write specific settings in the config
            # dict with yaml values, key must be a str
            # micro_sam: ...
            # config=
        )

        _check_model(model_description, input_paths, result_paths)

        save_bioimageio_package(model_description, output_path=output_path)
DEFAULTS = {'authors': [Author(affiliation='University Goettingen', email=None, orcid=None, name='Anwai Archit', github_user='anwai98'), Author(affiliation='University Goettingen', email=None, orcid=None, name='Constantin Pape', github_user='constantinpape')], 'description': 'Finetuned Segment Anything Model for Microscopy', 'cite': [CiteEntry(text='Archit et al. Segment Anything for Microscopy', doi='10.1038/s41592-024-02580-4', url=None)], 'tags': ['segment-anything', 'instance-segmentation']}
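
The entries in DEFAULTS are used as fallbacks when export_sam_model is called without the corresponding keyword arguments (description, authors, cite, tags); they can be overridden through **kwargs. A minimal sketch, assuming hypothetical metadata (the author name, affiliation, and extra tag below are placeholders):

import bioimageio.spec.model.v0_5 as spec

# Placeholder metadata that overrides the DEFAULTS above when passed to export_sam_model.
custom_metadata = {
    "description": "Finetuned Segment Anything Model for nucleus segmentation",
    "authors": [spec.Author(name="Jane Doe", affiliation="Example Lab")],
    "tags": ["segment-anything", "instance-segmentation", "nuclei"],
}
# export_sam_model(image, label_image, model_type="vit_b", name="...", output_path="...", **custom_metadata)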
ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
def export_sam_model( image: numpy.ndarray, label_image: numpy.ndarray, model_type: str, name: str, output_path: Union[str, os.PathLike], checkpoint_path: Union[str, os.PathLike, NoneType] = None, **kwargs) -> None:

Export SAM model to BioImage.IO model format.

The exported model can be uploaded to bioimage.io and be used in tools that support the BioImage.IO model format.

Arguments:
  • image: The image for generating test data.
  • label_image: The segmentation corresponding to image. It is used to derive prompt inputs for the model.
  • model_type: The type of the SAM model.
  • name: The name of the exported model.
  • output_path: Where the exported model is saved.
  • checkpoint_path: Optional checkpoint for loading the SAM model.
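
Below is a minimal usage sketch, not part of the module itself: it exports one of the default SAM models with small synthetic test data. The array contents, model type, and file names are placeholder assumptions; note that the label image must contain the object ids 1 and 2, since these are used to derive the test prompts.

import numpy as np
from micro_sam.bioimageio.model_export import export_sam_model

# Synthetic placeholder data: a 2D uint8 image and an instance segmentation
# that contains the object ids 1 and 2 (required for deriving the test prompts).
image = np.random.randint(0, 256, size=(512, 512), dtype="uint8")
labels = np.zeros((512, 512), dtype="int64")
labels[100:200, 120:220] = 1
labels[300:400, 320:420] = 2

export_sam_model(
    image=image,
    label_image=labels,
    model_type="vit_b",                      # SAM model type to export
    name="sam-export-example",               # name of the exported model
    output_path="./sam_export_example.zip",  # where the BioImage.IO package is saved
    # checkpoint_path="./finetuned_sam.pt",  # optional: export a custom finetuned checkpoint instead
)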