micro_sam.bioimageio.model_export
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import xarray
import numpy as np
import matplotlib.pyplot as plt

import torch

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.digest_spec import create_sample_for_model

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
from ..prompt_based_segmentation import _compute_logits_from_mask
from .predictor_adaptor import PredictorAdaptor


DEFAULTS = {
    "authors": [
        spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
        spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"),
    ],
    "description": "Finetuned Segment Anything Model for Microscopy",
    "cite": [
        spec.CiteEntry(
            text="Archit et al. Segment Anything for Microscopy",
            doi=spec.Doi("10.1038/s41592-024-02580-4")
        ),
    ],
    "tags": ["segment-anything", "instance-segmentation"],
}

# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269  # noqa
# This parameter was previously part of bioimageio.spec but has been removed, so we keep a copy of it here.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

    # For now we just generate a single box prompt here, but we could also generate more input prompts.
    generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=2,
        dilation_strength=2,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
    masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2])  # type: ignore
    point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])

    box_prompts = box_prompts.numpy()[None]
    point_prompts = point_prompts.numpy()[None]
    point_labels = point_labels.numpy()[None]

    # Generate logits from the two objects.
    mask_prompts = np.stack(
        [_compute_logits_from_mask(labels == 1), _compute_logits_from_mask(labels == 2)]
    )[None]

    predictor = PredictorAdaptor(model_type=model_type)
    predictor.load_state_dict(torch.load(checkpoint_path))

    input_ = util._to_image(image).transpose(2, 0, 1)[None]
    image_path = os.path.join(tmp_dir, "input.npy")
    np.save(image_path, input_)

    masks, scores, embeddings = predictor(
        image=torch.from_numpy(input_),
        embeddings=None,
        box_prompts=torch.from_numpy(box_prompts),
        point_prompts=torch.from_numpy(point_prompts),
        point_labels=torch.from_numpy(point_labels),
        mask_prompts=torch.from_numpy(mask_prompts),
    )

    box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy")
    point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy")
    point_label_path = os.path.join(tmp_dir, "point_labels.npy")
    mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy")
    np.save(box_prompt_path, box_prompts.astype("int64"))
    np.save(point_prompt_path, point_prompts)
    np.save(point_label_path, point_labels)
    np.save(mask_prompt_path, mask_prompts)

    mask_path = os.path.join(tmp_dir, "mask.npy")
    score_path = os.path.join(tmp_dir, "scores.npy")
    embed_path = os.path.join(tmp_dir, "embeddings.npy")
    np.save(mask_path, masks.numpy())
    np.save(score_path, scores.numpy())
    np.save(embed_path, embeddings.numpy())

    inputs = {
        "image": image_path,
        "box_prompts": box_prompt_path,
        "point_prompts": point_prompt_path,
        "point_labels": point_label_path,
        "mask_prompts": mask_prompt_path,
    }
    outputs = {"mask": mask_path, "score": score_path, "embeddings": embed_path}
    return inputs, outputs


def _write_documentation(doc, model_type, tmp_dir):
    tmp_doc_path = os.path.join(tmp_dir, "documentation.md")

    if doc is None:
        with open(tmp_doc_path, "w") as f:
            f.write("# Segment Anything for Microscopy\n")
            f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
            f.write("by training specialized models for microscopy data.\n")
        return tmp_doc_path

    elif os.path.exists(doc):
        return doc

    else:
        with open(tmp_doc_path, "w") as f:
            f.write(doc)
        return tmp_doc_path


def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
    # If we don't have a checkpoint we get the corresponding model from the registry.
    if checkpoint_path is None:
        model_registry = util.models()
        checkpoint_path = model_registry.fetch(model_type)
        return checkpoint_path, None

    # Otherwise we have to load the checkpoint to see if it is the state dict of an encoder,
    # or the checkpoint for a custom SAM model.
    state, model_state = util._load_checkpoint(checkpoint_path)

    if "model_state" in state:  # This is a finetuning checkpoint -> we have to resave the state.
        new_checkpoint_path = os.path.join(tmp_dir, f"{model_type}.pt")
        torch.save(model_state, new_checkpoint_path)

        # We may also have an instance segmentation decoder in that case.
        # If we have it we also resave this one and return it.
        if "decoder_state" in state:
            decoder_path = os.path.join(tmp_dir, f"{model_type}_decoder.pt")
            decoder_state = state["decoder_state"]
            torch.save(decoder_state, decoder_path)
        else:
            decoder_path = None

        return new_checkpoint_path, decoder_path

    else:  # This is a SAM encoder state -> we don't have to resave.
        return checkpoint_path, None


# TODO: Update this with our latest yaml file updates.
def _write_dependencies(dependency_file, require_mobile_sam):
    content = """name: sam
channels:
    - pytorch
    - conda-forge
dependencies:
    - segment-anything"""
    if require_mobile_sam:
        content += """
    - timm
    - pip:
        - git+https://github.com/ChaoningZhang/MobileSAM.git"""
    with open(dependency_file, "w") as f:
        f.write(content)


def _generate_covers(input_paths, result_paths, tmp_dir):
    image = np.load(input_paths["image"]).squeeze()
    prompts = np.load(input_paths["box_prompts"])
    mask = np.load(result_paths["mask"])

    # Create the image overlay.
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    elif image.shape[0] == 3:
        overlay = image.transpose((1, 2, 0))
    else:
        overlay = image
    overlay = _enhance_image(overlay.astype("float32"))

    # Overlay the mask as an outline.
    overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2)

    # Overlay the bounding box prompt.
    prompt = prompts[0, 0][[1, 0, 3, 2]]
    prompt = np.array([prompt[:2], prompt[2:]])
    overlay = _overlay_box(overlay, prompt, outline_dilation=4)

    # Write the cover image.
    fig, ax = plt.subplots(1)
    ax.axis("off")
    ax.imshow(overlay.astype("uint8"))
    cover_path = os.path.join(tmp_dir, "cover.jpeg")
    plt.savefig(cover_path, bbox_inches="tight")
    plt.close()

    covers = [cover_path]
    return covers


def _check_model(model_description, input_paths, result_paths):
    # Load inputs.
    image = xarray.DataArray(np.load(input_paths["image"]), dims=("batch", "channel", "y", "x"))
    embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=("batch", "channel", "y", "x"))
    box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=("batch", "object", "channel"))
    point_prompts = xarray.DataArray(
        np.load(input_paths["point_prompts"]), dims=("batch", "object", "point", "channel")
    )
    point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=("batch", "object", "point"))
    mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=("batch", "object", "channel", "y", "x"))

    # Load outputs.
    mask = np.load(result_paths["mask"])

    with bioimageio.core.create_prediction_pipeline(model_description) as pp:

        # Check with all prompts. We only check the result for this setting,
        # because this was used to generate the test data.
        sample = create_sample_for_model(
            model=model_description,
            inputs={
                "image": image,
                "box_prompts": box_prompts,
                "point_prompts": point_prompts,
                "point_labels": point_labels,
                "mask_prompts": mask_prompts,
                "embeddings": embeddings,
            },
        ).as_single_block()
        prediction = pp.predict_sample_block(sample)

        predicted_mask = prediction.blocks["masks"].data.data
        assert predicted_mask.shape == mask.shape
        assert np.allclose(mask, predicted_mask)

        # Run the checks with partial prompts.
        prompt_kwargs = [
            # With boxes.
            {"box_prompts": box_prompts},
            # With point prompts.
            {"point_prompts": point_prompts, "point_labels": point_labels},
            # With masks.
            {"mask_prompts": mask_prompts},
            # With boxes and points.
            {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
            # With boxes and masks.
            {"box_prompts": box_prompts, "mask_prompts": mask_prompts},
            # With points and masks.
            {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
        ]

        for kwargs in prompt_kwargs:
            sample = create_sample_for_model(
                model=model_description, inputs={"image": image, "embeddings": embeddings, **kwargs},
            ).as_single_block()
            prediction = pp.predict_sample_block(sample)
            predicted_mask = prediction.blocks["masks"].data.data
            assert predicted_mask.shape == mask.shape


def export_sam_model(
    image: np.ndarray,
    label_image: np.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    **kwargs
) -> None:
    """Export SAM model to BioImage.IO model format.

    The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
    be used in tools that support the BioImage.IO model format.

    Args:
        image: The image for generating test data.
        label_image: The segmentation corresponding to `image`.
            It is used to derive prompt inputs for the model.
        model_type: The type of the SAM model.
        name: The name of the exported model.
        output_path: Where the exported model is saved.
        checkpoint_path: Optional checkpoint for loading the SAM model.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path, decoder_path = _get_checkpoint(model_type, checkpoint_path, tmp_dir)
        input_paths, result_paths = _create_test_inputs_and_outputs(
            image, label_image, model_type, checkpoint_path, tmp_dir,
        )
        input_descriptions = [
            # First input: the image data.
            spec.InputTensorDescr(
                id=spec.TensorId("image"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: to support 1 and 3 channels we can add another preprocessing.
                    # Best solution: Have a pre-processing for this! (1C -> RGB)
                    spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
                ],
                test_tensor=spec.FileDescr(source=input_paths["image"]),
                data=spec.IntervalOrRatioDataDescr(type="uint8")
            ),

            # Second input: the box prompts (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("box_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Third input: the point prompt coordinates (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("point_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fourth input: the point prompt labels (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("point_labels"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fifth input: the mask prompts (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("mask_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=["channel"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=256),
                ],
                test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

            # Sixth input: the image embeddings (optional).
            spec.InputTensorDescr(
                id=spec.TensorId("embeddings"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we currently have to specify all the channel names.
                    # (It would be nice to also support size.)
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
                ],
                test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

        ]

        output_descriptions = [
            # First output: the mask predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("masks"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use a data dependent size here to avoid dependency on optional inputs.
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be 3 once we use multi-masking.
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("y"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
                        )
                    ),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("x"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
                        )
                    )
                ],
                data=spec.IntervalOrRatioDataDescr(type="uint8"),
                test_tensor=spec.FileDescr(source=result_paths["mask"])
            ),

            # Second output: the score predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("scores"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use a data dependent size here to avoid dependency on optional inputs.
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be 3 once we use multi-masking.
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["score"])
            ),

            # Third output: the image embeddings.
            spec.OutputTensorDescr(
                id=spec.TensorId("embeddings"),
                axes=[
                    spec.BatchAxis(size=1),
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["embeddings"])
            )
        ]

        architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(architecture_path),
            callable="PredictorAdaptor",
            kwargs={"model_type": model_type}
        )

        dependency_file = os.path.join(tmp_dir, "environment.yaml")
        _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t"))

        weight_descriptions = spec.WeightsDescr(
            pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
                source=Path(checkpoint_path),
                architecture=architecture,
                pytorch_version=spec.Version(torch.__version__),
                dependencies=spec.EnvironmentFileDescr(source=dependency_file),
            )
        )

        doc_path = _write_documentation(kwargs.get("documentation", None), model_type, tmp_dir)

        covers = kwargs.get("covers", None)
        if covers is None:
            covers = _generate_covers(input_paths, result_paths, tmp_dir)
        else:
            assert all(os.path.exists(cov) for cov in covers)

        # The uploader information is only added if explicitly passed.
        extra_kwargs = {}
        if "id" in kwargs:
            extra_kwargs["id"] = kwargs["id"]
        if "id_emoji" in kwargs:
            extra_kwargs["id_emoji"] = kwargs["id_emoji"]
        if "uploader" in kwargs:
            extra_kwargs["uploader"] = kwargs["uploader"]
        if "version" in kwargs:
            extra_kwargs["version"] = kwargs["version"]

        if decoder_path is not None:
            extra_kwargs["attachments"] = [spec.FileDescr(source=decoder_path)]

        model_description = spec.ModelDescr(
            name=name,
            inputs=input_descriptions,
            outputs=output_descriptions,
            weights=weight_descriptions,
            description=kwargs.get("description", DEFAULTS["description"]),
            authors=kwargs.get("authors", DEFAULTS["authors"]),
            cite=kwargs.get("cite", DEFAULTS["cite"]),
            license=spec.LicenseId("CC-BY-4.0"),
            documentation=Path(doc_path),
            git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
            tags=kwargs.get("tags", DEFAULTS["tags"]),
            covers=covers,
            **extra_kwargs,
            # TODO write specific settings in the config
            # dict with yaml values, key must be a str
            # micro_sam: ...
            # config=
        )

        _check_model(model_description, input_paths, result_paths)

        save_bioimageio_package(model_description, output_path=output_path)
DEFAULTS =
{'authors': [Author(affiliation='University Goettingen', email=None, orcid=None, name='Anwai Archit', github_user='anwai98'), Author(affiliation='University Goettingen', email=None, orcid=None, name='Constantin Pape', github_user='constantinpape')], 'description': 'Finetuned Segment Anything Model for Microscopy', 'cite': [CiteEntry(text='Archit et al. Segment Anything for Microscopy', doi='10.1038/s41592-024-02580-4', url=None)], 'tags': ['segment-anything', 'instance-segmentation']}
ARBITRARY_SIZE =
ParameterizedSize(min=1, step=1)
def export_sam_model(image: numpy.ndarray, label_image: numpy.ndarray, model_type: str, name: str, output_path: Union[str, os.PathLike], checkpoint_path: Optional[Union[str, os.PathLike]] = None, **kwargs) -> None:
Export SAM model to BioImage.IO model format.
The exported model can be uploaded to bioimage.io and be used in tools that support the BioImage.IO model format.
Arguments:
- image: The image for generating test data.
- label_image: The segmentation corresponding to `image`. It is used to derive prompt inputs for the model.
- model_type: The type of the SAM model.
- name: The name of the exported model.
- output_path: Where the exported model is saved.
- checkpoint_path: Optional checkpoint for loading the SAM model.
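A minimal usage sketch. The synthetic arrays, model name, and output path below are illustrative placeholders; in practice you would pass a real microscopy image together with its instance segmentation. Note that the segmentation must contain the object ids 1 and 2, since these are used to derive the test prompts for the exported model.

import numpy as np
from micro_sam.bioimageio.model_export import export_sam_model

# Placeholder example data: replace with a real image and its instance segmentation.
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
label_image = np.zeros((512, 512), dtype="uint32")
label_image[100:200, 100:200] = 1  # first object (id 1)
label_image[300:400, 250:350] = 2  # second object (id 2)

export_sam_model(
    image=image,
    label_image=label_image,
    model_type="vit_b",  # the SAM image encoder variant the weights belong to
    name="sam-microscopy-example",
    output_path="./exported_sam_model.zip",
    checkpoint_path=None,  # None: fetch the released weights; or pass a finetuned checkpoint
)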