micro_sam.bioimageio.model_export

  1import os
  2import tempfile
  3from pathlib import Path
  4from typing import Optional, Union
  5
  6import xarray
  7import numpy as np
  8import matplotlib.pyplot as plt
  9
 10import torch
 11
 12import bioimageio.core
 13import bioimageio.spec.model.v0_5 as spec
 14from bioimageio.spec import save_bioimageio_package
 15from bioimageio.core.digest_spec import create_sample_for_model
 16
 17from .. import util
 18from ..prompt_generators import PointAndBoxPromptGenerator
 19from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
 20from ..prompt_based_segmentation import _compute_logits_from_mask
 21from .predictor_adaptor import PredictorAdaptor
 22
 23
 24DEFAULTS = {
 25    "authors": [
 26        spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
 27        spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"),
 28    ],
 29    "description": "Finetuned Segment Anything Model for Microscopy",
 30    "cite": [
 31        spec.CiteEntry(text="Archit et al. Segment Anything for Microscopy", doi=spec.Doi("10.1101/2023.08.21.554208")),
 32    ],
 33    "tags": ["segment-anything", "instance-segmentation"],
 34}
 35
 36
 37def _create_test_inputs_and_outputs(
 38    image,
 39    labels,
 40    model_type,
 41    checkpoint_path,
 42    tmp_dir,
 43):
 44    # For now we generate one set of point, box and mask prompts per object, but we could also generate more input prompts.
 45    generator = PointAndBoxPromptGenerator(
 46        n_positive_points=1,
 47        n_negative_points=2,
 48        dilation_strength=2,
 49        get_point_prompts=True,
 50        get_box_prompts=True,
 51    )
 52    centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
 53    masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2])  # type: ignore
 54    point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])
 55
 56    box_prompts = box_prompts.numpy()[None]
 57    point_prompts = point_prompts.numpy()[None]
 58    point_labels = point_labels.numpy()[None]
 59
 60    # Generate the mask prompts as logits from the masks of the two objects.
 61    mask_prompts = np.stack(
 62        [
 63            _compute_logits_from_mask(labels == 1),
 64            _compute_logits_from_mask(labels == 2),
 65        ]
 66    )[None]
 67
 68    predictor = PredictorAdaptor(model_type=model_type)
 69    predictor.load_state_dict(torch.load(checkpoint_path))
 70
 71    input_ = util._to_image(image).transpose(2, 0, 1)[None]
 72    image_path = os.path.join(tmp_dir, "input.npy")
 73    np.save(image_path, input_)
 74
 75    masks, scores, embeddings = predictor(
 76        image=torch.from_numpy(input_),
 77        embeddings=None,
 78        box_prompts=torch.from_numpy(box_prompts),
 79        point_prompts=torch.from_numpy(point_prompts),
 80        point_labels=torch.from_numpy(point_labels),
 81        mask_prompts=torch.from_numpy(mask_prompts),
 82    )
 83
 84    box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy")
 85    point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy")
 86    point_label_path = os.path.join(tmp_dir, "point_labels.npy")
 87    mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy")
 88    np.save(box_prompt_path, box_prompts.astype("int64"))
 89    np.save(point_prompt_path, point_prompts)
 90    np.save(point_label_path, point_labels)
 91    np.save(mask_prompt_path, mask_prompts)
 92
 93    mask_path = os.path.join(tmp_dir, "mask.npy")
 94    score_path = os.path.join(tmp_dir, "scores.npy")
 95    embed_path = os.path.join(tmp_dir, "embeddings.npy")
 96    np.save(mask_path, masks.numpy())
 97    np.save(score_path, scores.numpy())
 98    np.save(embed_path, embeddings.numpy())
 99
100    inputs = {
101        "image": image_path,
102        "box_prompts": box_prompt_path,
103        "point_prompts": point_prompt_path,
104        "point_labels": point_label_path,
105        "mask_prompts": mask_prompt_path,
106    }
107    outputs = {
108        "mask": mask_path,
109        "score": score_path,
110        "embeddings": embed_path
111    }
112    return inputs, outputs
113
114
115def _write_documentation(doc, model_type, tmp_dir):
116    tmp_doc_path = os.path.join(tmp_dir, "documentation.md")
117
118    if doc is None:
119        with open(tmp_doc_path, "w") as f:
120            f.write("# Segment Anything for Microscopy\n")
121            f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
122            f.write("by training specialized models for microscopy data.\n")
123        return tmp_doc_path
124
125    elif os.path.exists(doc):
126        return doc
127
128    else:
129        with open(tmp_doc_path, "w") as f:
130            f.write(doc)
131        return tmp_doc_path
132
133
134def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
135    # If we don't have a checkpoint we get the corresponding model from the registry.
136    if checkpoint_path is None:
137        model_registry = util.models()
138        checkpoint_path = model_registry.fetch(model_type)
139        return checkpoint_path, None
140
141    # Otherwise we have to load the checkpoint to see if it is the state dict of an encoder,
142    # or the checkpoint for a custom SAM model.
143    state, model_state = util._load_checkpoint(checkpoint_path)
144
145    if "model_state" in state:  # This is a finetuning checkpoint -> we have to resave the state.
146        new_checkpoint_path = os.path.join(tmp_dir, f"{model_type}.pt")
147        torch.save(model_state, new_checkpoint_path)
148
149        # We may also have an instance segmentation decoder in that case.
150        # If we have it we also resave this one and return it.
151        if "decoder_state" in state:
152            decoder_path = os.path.join(tmp_dir, f"{model_type}_decoder.pt")
153            decoder_state = state["decoder_state"]
154            torch.save(decoder_state, decoder_path)
155        else:
156            decoder_path = None
157
158        return new_checkpoint_path, decoder_path
159
160    else:  # This is a SAM encoder state -> we don't have to resave.
161        return checkpoint_path, None
162
163
164def _write_dependencies(dependency_file, require_mobile_sam):
165    content = """name: sam
166channels:
167    - pytorch
168    - conda-forge
169dependencies:
170    - segment-anything"""
171    if require_mobile_sam:
172        content += """
173    - pip:
174        - git+https://github.com/ChaoningZhang/MobileSAM.git"""
175    with open(dependency_file, "w") as f:
176        f.write(content)
177
178
179def _generate_covers(input_paths, result_paths, tmp_dir):
180    image = np.load(input_paths["image"]).squeeze()
181    prompts = np.load(input_paths["box_prompts"])
182    mask = np.load(result_paths["mask"])
183
184    # create the image overlay
185    if image.ndim == 2:
186        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
187    elif image.shape[0] == 3:
188        overlay = image.transpose((1, 2, 0))
189    else:
190        overlay = image
191    overlay = _enhance_image(overlay.astype("float32"))
192
193    # overlay the mask as outline
194    overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2)
195
196    # overlay the bounding box prompt
197    prompt = prompts[0, 0][[1, 0, 3, 2]]
198    prompt = np.array([prompt[:2], prompt[2:]])
199    overlay = _overlay_box(overlay, prompt, outline_dilation=4)
200
201    # write the cover image
202    fig, ax = plt.subplots(1)
203    ax.axis("off")
204    ax.imshow(overlay.astype("uint8"))
205    cover_path = os.path.join(tmp_dir, "cover.jpeg")
206    plt.savefig(cover_path, bbox_inches="tight")
207    plt.close()
208
209    covers = [cover_path]
210    return covers
211
212
213def _check_model(model_description, input_paths, result_paths):
214    # Load inputs.
215    image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
216    embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
217    box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
218    point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic"))
219    point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
220    mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))
221
222    # Load outputs.
223    mask = np.load(result_paths["mask"])
224
225    with bioimageio.core.create_prediction_pipeline(model_description) as pp:
226
227        # Check with all prompts. We only check the result for this setting,
228        # because this was used to generate the test data.
229        sample = create_sample_for_model(
230            model=model_description,
231            image=image,
232            box_prompts=box_prompts,
233            point_prompts=point_prompts,
234            point_labels=point_labels,
235            mask_prompts=mask_prompts,
236            embeddings=embeddings,
237        ).as_single_block()
238        prediction = pp.predict_sample_block(sample)
239
240        predicted_mask = prediction.blocks["masks"].data.data
241        assert predicted_mask.shape == mask.shape
242        assert np.allclose(mask, predicted_mask)
243
244        # Run the checks with partial prompts.
245        prompt_kwargs = [
246            # With boxes.
247            {"box_prompts": box_prompts},
248            # With point prompts.
249            {"point_prompts": point_prompts, "point_labels": point_labels},
250            # With masks.
251            {"mask_prompts": mask_prompts},
252            # With boxes and points.
253            {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
254            # With boxes and masks.
255            {"box_prompts": box_prompts, "mask_prompts": mask_prompts},
256            # With points and masks.
257            {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
258        ]
259
260        for kwargs in prompt_kwargs:
261            sample = create_sample_for_model(
262                model=model_description, image=image, embeddings=embeddings, **kwargs
263            ).as_single_block()
264            prediction = pp.predict_sample_block(sample)
265            predicted_mask = prediction.blocks["masks"].data.data
266            assert predicted_mask.shape == mask.shape
267
268
269def export_sam_model(
270    image: np.ndarray,
271    label_image: np.ndarray,
272    model_type: str,
273    name: str,
274    output_path: Union[str, os.PathLike],
275    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
276    **kwargs
277) -> None:
278    """Export SAM model to BioImage.IO model format.
279
280    The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
281    be used in tools that support the BioImage.IO model format.
282
283    Args:
284        image: The image for generating test data.
285        label_image: The segmentation corresponding to `image`.
286            It is used to derive prompt inputs for the model.
287        model_type: The type of the SAM model.
288        name: The name of the exported model.
289        output_path: Where the exported model is saved.
290        checkpoint_path: Optional checkpoint for loading the SAM model.
291    """
292    with tempfile.TemporaryDirectory() as tmp_dir:
293        checkpoint_path, decoder_path = _get_checkpoint(model_type, checkpoint_path, tmp_dir)
294        input_paths, result_paths = _create_test_inputs_and_outputs(
295            image, label_image, model_type, checkpoint_path, tmp_dir,
296        )
297        input_descriptions = [
298            # First input: the image data.
299            spec.InputTensorDescr(
300                id=spec.TensorId("image"),
301                axes=[
302                    spec.BatchAxis(size=1),
303                    # NOTE: to support both 1 and 3 channel inputs we could add another pre-processing step
304                    # (ideally one that converts single-channel images to RGB).
305                    spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
306                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE),
307                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE),
308                ],
309                test_tensor=spec.FileDescr(source=input_paths["image"]),
310                data=spec.IntervalOrRatioDataDescr(type="uint8")
311            ),
312
313            # Second input: the box prompts (optional)
314            spec.InputTensorDescr(
315                id=spec.TensorId("box_prompts"),
316                optional=True,
317                axes=[
318                    spec.BatchAxis(size=1),
319                    spec.IndexInputAxis(
320                        id=spec.AxisId("object"),
321                        size=spec.ARBITRARY_SIZE
322                    ),
323                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
324                ],
325                test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
326                data=spec.IntervalOrRatioDataDescr(type="int64")
327            ),
328
329            # Third input: the point prompt coordinates (optional)
330            spec.InputTensorDescr(
331                id=spec.TensorId("point_prompts"),
332                optional=True,
333                axes=[
334                    spec.BatchAxis(size=1),
335                    spec.IndexInputAxis(
336                        id=spec.AxisId("object"),
337                        size=spec.ARBITRARY_SIZE
338                    ),
339                    spec.IndexInputAxis(
340                        id=spec.AxisId("point"),
341                        size=spec.ARBITRARY_SIZE
342                    ),
343                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
344                ],
345                test_tensor=spec.FileDescr(source=input_paths["point_prompts"]),
346                data=spec.IntervalOrRatioDataDescr(type="int64")
347            ),
348
349            # Fourth input: the point prompt labels (optional)
350            spec.InputTensorDescr(
351                id=spec.TensorId("point_labels"),
352                optional=True,
353                axes=[
354                    spec.BatchAxis(size=1),
355                    spec.IndexInputAxis(
356                        id=spec.AxisId("object"),
357                        size=spec.ARBITRARY_SIZE
358                    ),
359                    spec.IndexInputAxis(
360                        id=spec.AxisId("point"),
361                        size=spec.ARBITRARY_SIZE
362                    ),
363                ],
364                test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
365                data=spec.IntervalOrRatioDataDescr(type="int64")
366            ),
367
368            # Fifth input: the mask prompts (optional)
369            spec.InputTensorDescr(
370                id=spec.TensorId("mask_prompts"),
371                optional=True,
372                axes=[
373                    spec.BatchAxis(size=1),
374                    spec.IndexInputAxis(
375                        id=spec.AxisId("object"),
376                        size=spec.ARBITRARY_SIZE
377                    ),
378                    spec.ChannelAxis(channel_names=["channel"]),
379                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
380                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=256),
381                ],
382                test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]),
383                data=spec.IntervalOrRatioDataDescr(type="float32")
384            ),
385
386            # Sixth input: the image embeddings (optional)
387            spec.InputTensorDescr(
388                id=spec.TensorId("embeddings"),
389                optional=True,
390                axes=[
391                    spec.BatchAxis(size=1),
392                    # NOTE: we currently have to specify all the channel names
393                    # (It would be nice to also support size)
394                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
395                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
396                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
397                ],
398                test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
399                data=spec.IntervalOrRatioDataDescr(type="float32")
400            ),
401
402        ]
403
404        output_descriptions = [
405            # First output: The mask predictions.
406            spec.OutputTensorDescr(
407                id=spec.TensorId("masks"),
408                axes=[
409                    spec.BatchAxis(size=1),
410                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
411                    spec.IndexOutputAxis(
412                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
413                    ),
414                    # NOTE: this could be a 3 once we use multi-masking
415                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
416                    spec.SpaceOutputAxis(
417                        id=spec.AxisId("y"),
418                        size=spec.SizeReference(
419                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
420                        )
421                    ),
422                    spec.SpaceOutputAxis(
423                        id=spec.AxisId("x"),
424                        size=spec.SizeReference(
425                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
426                        )
427                    )
428                ],
429                data=spec.IntervalOrRatioDataDescr(type="uint8"),
430                test_tensor=spec.FileDescr(source=result_paths["mask"])
431            ),
432
433            # The score predictions
434            spec.OutputTensorDescr(
435                id=spec.TensorId("scores"),
436                axes=[
437                    spec.BatchAxis(size=1),
438                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
439                    spec.IndexOutputAxis(
440                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
441                    ),
442                    # NOTE: this could be a 3 once we use multi-masking
443                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
444                ],
445                data=spec.IntervalOrRatioDataDescr(type="float32"),
446                test_tensor=spec.FileDescr(source=result_paths["score"])
447            ),
448
449            # The image embeddings
450            spec.OutputTensorDescr(
451                id=spec.TensorId("embeddings"),
452                axes=[
453                    spec.BatchAxis(size=1),
454                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
455                    spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
456                    spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
457                ],
458                data=spec.IntervalOrRatioDataDescr(type="float32"),
459                test_tensor=spec.FileDescr(source=result_paths["embeddings"])
460            )
461        ]
462
463        architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
464        architecture = spec.ArchitectureFromFileDescr(
465            source=Path(architecture_path),
466            callable="PredictorAdaptor",
467            kwargs={"model_type": model_type}
468        )
469
470        dependency_file = os.path.join(tmp_dir, "environment.yaml")
471        _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t"))
472
473        weight_descriptions = spec.WeightsDescr(
474            pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
475                source=Path(checkpoint_path),
476                architecture=architecture,
477                pytorch_version=spec.Version(torch.__version__),
478                dependencies=spec.EnvironmentFileDescr(source=dependency_file),
479            )
480        )
481
482        doc_path = _write_documentation(kwargs.get("documentation", None), model_type, tmp_dir)
483
484        covers = kwargs.get("covers", None)
485        if covers is None:
486            covers = _generate_covers(input_paths, result_paths, tmp_dir)
487        else:
488            assert all(os.path.exists(cov) for cov in covers)
489
490        # the uploader information is only added if explicitly passed
491        extra_kwargs = {}
492        if "id" in kwargs:
493            extra_kwargs["id"] = kwargs["id"]
494        if "id_emoji" in kwargs:
495            extra_kwargs["id_emoji"] = kwargs["id_emoji"]
496        if "uploader" in kwargs:
497            extra_kwargs["uploader"] = kwargs["uploader"]
498
499        if decoder_path is not None:
500            extra_kwargs["attachments"] = [spec.FileDescr(source=decoder_path)]
501
502        model_description = spec.ModelDescr(
503            name=name,
504            inputs=input_descriptions,
505            outputs=output_descriptions,
506            weights=weight_descriptions,
507            description=kwargs.get("description", DEFAULTS["description"]),
508            authors=kwargs.get("authors", DEFAULTS["authors"]),
509            cite=kwargs.get("cite", DEFAULTS["cite"]),
510            license=spec.LicenseId("CC-BY-4.0"),
511            documentation=Path(doc_path),
512            git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
513            tags=kwargs.get("tags", DEFAULTS["tags"]),
514            covers=covers,
515            **extra_kwargs,
516            # TODO write specific settings in the config
517            # dict with yaml values, key must be a str
518            # micro_sam: ...
519            # config=
520        )
521
522        _check_model(model_description, input_paths, result_paths)
523
524        save_bioimageio_package(model_description, output_path=output_path)
DEFAULTS = {
    'authors': [
        Author(affiliation='University Goettingen', email=None, orcid=None, name='Anwai Archit', github_user='anwai98'),
        Author(affiliation='University Goettingen', email=None, orcid=None, name='Constantin Pape', github_user='constantinpape'),
    ],
    'description': 'Finetuned Segment Anything Model for Microscopy',
    'cite': [CiteEntry(text='Archit et al. Segment Anything for Microscopy', doi='10.1101/2023.08.21.554208', url=None)],
    'tags': ['segment-anything', 'instance-segmentation'],
}
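These defaults fill the `description`, `authors`, `cite` and `tags` metadata of the exported model; `export_sam_model` only falls back to them when the corresponding keyword argument is not passed. A minimal sketch of overriding a subset of them via `**kwargs` (the values below are placeholders, not part of micro_sam):

```python
# Hypothetical metadata override; all values are examples only.
custom_metadata = {
    "description": "Segment Anything Model finetuned on my microscopy data",
    "tags": ["segment-anything", "instance-segmentation", "my-modality"],
    # A markdown string is also accepted for the documentation; it is written
    # to a temporary file by _write_documentation during the export.
    "documentation": "# My model\nTrained on ... for ...",
}
# These would be forwarded as export_sam_model(..., **custom_metadata).
```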
def export_sam_model(image: numpy.ndarray, label_image: numpy.ndarray, model_type: str, name: str, output_path: Union[str, os.PathLike], checkpoint_path: Union[str, os.PathLike, NoneType] = None, **kwargs) -> None:

Export SAM model to BioImage.IO model format.

The exported model can be uploaded to bioimage.io and be used in tools that support the BioImage.IO model format.

Arguments:
  • image: The image for generating test data.
  • label_image: The segmentation corresponding to image. It is used to derive prompt inputs for the model.
  • model_type: The type of the SAM model.
  • name: The name of the exported model.
  • output_path: Where the exported model is saved.
  • checkpoint_path: Optional checkpoint for loading the SAM model.
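A minimal usage sketch: it assumes a 2D test image and a matching instance segmentation that contains at least the object ids 1 and 2 (prompts are derived for exactly these two objects in `_create_test_inputs_and_outputs`); the file paths and the `vit_b` model type are placeholders.

```python
import imageio.v3 as imageio
from micro_sam.bioimageio.model_export import export_sam_model

# Placeholder paths: a 2D test image and an instance segmentation
# that contains (at least) the object ids 1 and 2.
image = imageio.imread("example_image.tif")
label_image = imageio.imread("example_labels.tif")

export_sam_model(
    image=image,
    label_image=label_image,
    model_type="vit_b",                     # SAM model type to export
    name="my-finetuned-sam",
    output_path="my-finetuned-sam.zip",     # exported BioImage.IO package
    checkpoint_path="checkpoints/best.pt",  # optional finetuned checkpoint; omit to use the released weights
)
```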