micro_sam.bioimageio.model_export
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import xarray
import numpy as np
import matplotlib.pyplot as plt

import torch

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.digest_spec import create_sample_for_model

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
from ..prompt_based_segmentation import _compute_logits_from_mask
from .predictor_adaptor import PredictorAdaptor


DEFAULTS = {
    "authors": [
        spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
        spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"),
    ],
    "description": "Finetuned Segment Anything Model for Microscopy",
    "cite": [
        spec.CiteEntry(
            text="Archit et al. Segment Anything for Microscopy",
            doi=spec.Doi("10.1038/s41592-024-02580-4")
        ),
    ],
    "tags": ["segment-anything", "instance-segmentation"],
}

# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269  # noqa
# This parameter used to be part of bioimageio.spec but has been removed, so we keep a copy of it here.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

    # For now we just generate a single box prompt here, but we could also generate more input prompts.
    generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=2,
        dilation_strength=2,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
    masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2])  # type: ignore
    point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])

    box_prompts = box_prompts.numpy()[None]
    point_prompts = point_prompts.numpy()[None]
    point_labels = point_labels.numpy()[None]

    # Generate logits from the two objects.
    mask_prompts = np.stack(
        [_compute_logits_from_mask(labels == 1), _compute_logits_from_mask(labels == 2)]
    )[None]

    predictor = PredictorAdaptor(model_type=model_type)
    predictor.load_state_dict(torch.load(checkpoint_path))

    input_ = util._to_image(image).transpose(2, 0, 1)[None]
    image_path = os.path.join(tmp_dir, "input.npy")
    np.save(image_path, input_)

    masks, scores, embeddings = predictor(
        image=torch.from_numpy(input_),
        embeddings=None,
        box_prompts=torch.from_numpy(box_prompts),
        point_prompts=torch.from_numpy(point_prompts),
        point_labels=torch.from_numpy(point_labels),
        mask_prompts=torch.from_numpy(mask_prompts),
    )

    box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy")
    point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy")
    point_label_path = os.path.join(tmp_dir, "point_labels.npy")
    mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy")
    np.save(box_prompt_path, box_prompts.astype("int64"))
    np.save(point_prompt_path, point_prompts)
    np.save(point_label_path, point_labels)
    np.save(mask_prompt_path, mask_prompts)

    mask_path = os.path.join(tmp_dir, "mask.npy")
    score_path = os.path.join(tmp_dir, "scores.npy")
    embed_path = os.path.join(tmp_dir, "embeddings.npy")
    np.save(mask_path, masks.numpy())
    np.save(score_path, scores.numpy())
    np.save(embed_path, embeddings.numpy())

    inputs = {
        "image": image_path,
        "box_prompts": box_prompt_path,
        "point_prompts": point_prompt_path,
        "point_labels": point_label_path,
        "mask_prompts": mask_prompt_path,
    }
    outputs = {"mask": mask_path, "score": score_path, "embeddings": embed_path}
    return inputs, outputs


def _write_documentation(doc, model_type, tmp_dir):
    tmp_doc_path = os.path.join(tmp_dir, "documentation.md")

    if doc is None:
        with open(tmp_doc_path, "w") as f:
            f.write("# Segment Anything for Microscopy\n")
            f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
            f.write("by training specialized models for microscopy data.\n")
        return tmp_doc_path

    elif os.path.exists(doc):
        return doc

    else:
        with open(tmp_doc_path, "w") as f:
            f.write(doc)
        return tmp_doc_path


def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
    # If we don't have a checkpoint we get the corresponding model from the registry.
    if checkpoint_path is None:
        model_registry = util.models()
        checkpoint_path = model_registry.fetch(model_type)
        return checkpoint_path, None

    # Otherwise we have to load the checkpoint to see if it is the state dict of an encoder,
    # or the checkpoint for a custom SAM model.
    state, model_state = util._load_checkpoint(checkpoint_path)

    if "model_state" in state:  # This is a finetuning checkpoint -> we have to resave the state.
        new_checkpoint_path = os.path.join(tmp_dir, f"{model_type}.pt")
        torch.save(model_state, new_checkpoint_path)

        # We may also have an instance segmentation decoder in that case.
        # If we have it we also resave this one and return it.
        if "decoder_state" in state:
            decoder_path = os.path.join(tmp_dir, f"{model_type}_decoder.pt")
            decoder_state = state["decoder_state"]
            torch.save(decoder_state, decoder_path)
        else:
            decoder_path = None

        return new_checkpoint_path, decoder_path

    else:  # This is a SAM encoder state -> we don't have to resave.
        return checkpoint_path, None


# TODO: Update this with our latest yaml file updates.
def _write_dependencies(dependency_file, require_mobile_sam):
    content = """name: sam
channels:
    - pytorch
    - conda-forge
dependencies:
    - segment-anything"""
    if require_mobile_sam:
        content += """
    - timm
    - pip:
        - git+https://github.com/ChaoningZhang/MobileSAM.git"""
    with open(dependency_file, "w") as f:
        f.write(content)


def _generate_covers(input_paths, result_paths, tmp_dir):
    image = np.load(input_paths["image"]).squeeze()
    prompts = np.load(input_paths["box_prompts"])
    mask = np.load(result_paths["mask"])

    # create the image overlay
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    elif image.shape[0] == 3:
        overlay = image.transpose((1, 2, 0))
    else:
        overlay = image
    overlay = _enhance_image(overlay.astype("float32"))

    # overlay the mask as outline
    overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2)

    # overlay the bounding box prompt
    prompt = prompts[0, 0][[1, 0, 3, 2]]
    prompt = np.array([prompt[:2], prompt[2:]])
    overlay = _overlay_box(overlay, prompt, outline_dilation=4)

    # write the cover image
    fig, ax = plt.subplots(1)
    ax.axis("off")
    ax.imshow(overlay.astype("uint8"))
    cover_path = os.path.join(tmp_dir, "cover.jpeg")
    plt.savefig(cover_path, bbox_inches="tight")
    plt.close()

    covers = [cover_path]
    return covers


def _check_model(model_description, input_paths, result_paths):
    # Load inputs.
    image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
    embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
    box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
    point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("bhwc"))
    point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
    mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))

    # Load outputs.
    mask = np.load(result_paths["mask"])

    with bioimageio.core.create_prediction_pipeline(model_description) as pp:

        # Check with all prompts. We only check the result for this setting,
        # because this was used to generate the test data.
        sample = create_sample_for_model(
            model=model_description,
            image=image,
            box_prompts=box_prompts,
            point_prompts=point_prompts,
            point_labels=point_labels,
            mask_prompts=mask_prompts,
            embeddings=embeddings,
        ).as_single_block()
        prediction = pp.predict_sample_block(sample)

        predicted_mask = prediction.blocks["masks"].data.data
        assert predicted_mask.shape == mask.shape
        assert np.allclose(mask, predicted_mask)

        # Run the checks with partial prompts.
        prompt_kwargs = [
            # With boxes.
            {"box_prompts": box_prompts},
            # With point prompts.
            {"point_prompts": point_prompts, "point_labels": point_labels},
            # With masks.
            {"mask_prompts": mask_prompts},
            # With boxes and points.
            {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
            # With boxes and masks.
            {"box_prompts": box_prompts, "mask_prompts": mask_prompts},
            # With points and masks.
            {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
        ]

        for kwargs in prompt_kwargs:
            sample = create_sample_for_model(
                model=model_description, image=image, embeddings=embeddings, **kwargs
            ).as_single_block()
            prediction = pp.predict_sample_block(sample)
            predicted_mask = prediction.blocks["masks"].data.data
            assert predicted_mask.shape == mask.shape


def export_sam_model(
    image: np.ndarray,
    label_image: np.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    **kwargs
) -> None:
    """Export SAM model to BioImage.IO model format.

    The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
    be used in tools that support the BioImage.IO model format.

    Args:
        image: The image for generating test data.
        label_image: The segmentation corresponding to `image`.
            It is used to derive prompt inputs for the model.
        model_type: The type of the SAM model.
        name: The name of the exported model.
        output_path: Where the exported model is saved.
        checkpoint_path: Optional checkpoint for loading the SAM model.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path, decoder_path = _get_checkpoint(model_type, checkpoint_path, tmp_dir)
        input_paths, result_paths = _create_test_inputs_and_outputs(
            image, label_image, model_type, checkpoint_path, tmp_dir,
        )
        input_descriptions = [
            # First input: the image data.
            spec.InputTensorDescr(
                id=spec.TensorId("image"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: to support 1 and 3 channels we can add another preprocessing.
                    # Best solution: Have a pre-processing for this! (1C -> RGB)
                    spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
                ],
                test_tensor=spec.FileDescr(source=input_paths["image"]),
                data=spec.IntervalOrRatioDataDescr(type="uint8")
            ),

            # Second input: the box prompts (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("box_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Third input: the point prompt coordinates (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("point_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fourth input: the point prompt labels (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("point_labels"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=ARBITRARY_SIZE
                    ),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fifth input: the mask prompts (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("mask_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=["channel"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=256),
                ],
                test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

            # Sixth input: the image embeddings (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("embeddings"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we currently have to specify all the channel names
                    # (It would be nice to also support size)
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
                ],
                test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

        ]

        output_descriptions = [
            # First output: The mask predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("masks"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be a 3 once we use multi-masking
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("y"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
                        )
                    ),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("x"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
                        )
                    )
                ],
                data=spec.IntervalOrRatioDataDescr(type="uint8"),
                test_tensor=spec.FileDescr(source=result_paths["mask"])
            ),

            # The score predictions
            spec.OutputTensorDescr(
                id=spec.TensorId("scores"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be a 3 once we use multi-masking
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["score"])
            ),

            # The image embeddings
            spec.OutputTensorDescr(
                id=spec.TensorId("embeddings"),
                axes=[
                    spec.BatchAxis(size=1),
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["embeddings"])
            )
        ]

        architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(architecture_path),
            callable="PredictorAdaptor",
            kwargs={"model_type": model_type}
        )

        dependency_file = os.path.join(tmp_dir, "environment.yaml")
        _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t"))

        weight_descriptions = spec.WeightsDescr(
            pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
                source=Path(checkpoint_path),
                architecture=architecture,
                pytorch_version=spec.Version(torch.__version__),
                dependencies=spec.EnvironmentFileDescr(source=dependency_file),
            )
        )

        doc_path = _write_documentation(kwargs.get("documentation", None), model_type, tmp_dir)

        covers = kwargs.get("covers", None)
        if covers is None:
            covers = _generate_covers(input_paths, result_paths, tmp_dir)
        else:
            assert all(os.path.exists(cov) for cov in covers)

        # the uploader information is only added if explicitly passed
        extra_kwargs = {}
        if "id" in kwargs:
            extra_kwargs["id"] = kwargs["id"]
        if "id_emoji" in kwargs:
            extra_kwargs["id_emoji"] = kwargs["id_emoji"]
        if "uploader" in kwargs:
            extra_kwargs["uploader"] = kwargs["uploader"]
        if "version" in kwargs:
            extra_kwargs["version"] = kwargs["version"]

        if decoder_path is not None:
            extra_kwargs["attachments"] = [spec.FileDescr(source=decoder_path)]

        model_description = spec.ModelDescr(
            name=name,
            inputs=input_descriptions,
            outputs=output_descriptions,
            weights=weight_descriptions,
            description=kwargs.get("description", DEFAULTS["description"]),
            authors=kwargs.get("authors", DEFAULTS["authors"]),
            cite=kwargs.get("cite", DEFAULTS["cite"]),
            license=spec.LicenseId("CC-BY-4.0"),
            documentation=Path(doc_path),
            git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
            tags=kwargs.get("tags", DEFAULTS["tags"]),
            covers=covers,
            **extra_kwargs,
            # TODO write specific settings in the config
            # dict with yaml values, key must be a str
            # micro_sam: ...
            # config=
        )

        _check_model(model_description, input_paths, result_paths)

        save_bioimageio_package(model_description, output_path=output_path)
DEFAULTS = {
    'authors': [
        Author(affiliation='University Goettingen', email=None, orcid=None, name='Anwai Archit', github_user='anwai98'),
        Author(affiliation='University Goettingen', email=None, orcid=None, name='Constantin Pape', github_user='constantinpape'),
    ],
    'description': 'Finetuned Segment Anything Model for Microscopy',
    'cite': [CiteEntry(text='Archit et al. Segment Anything for Microscopy', doi='10.1038/s41592-024-02580-4', url=None)],
    'tags': ['segment-anything', 'instance-segmentation'],
}
ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
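ARBITRARY_SIZE declares that an axis accepts any extent of the form min + n * step, i.e. any size of at least 1 here. A minimal sketch of how such a parameterized spatial axis is declared (mirroring its use in export_sam_model below; the variable names are illustrative, not part of this module):

import bioimageio.spec.model.v0_5 as spec

# Any positive integer size is valid for an axis parameterized like this.
arbitrary_size = spec.ParameterizedSize(min=1, step=1)
y_axis = spec.SpaceInputAxis(id=spec.AxisId("y"), size=arbitrary_size)
x_axis = spec.SpaceInputAxis(id=spec.AxisId("x"), size=arbitrary_size)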
def export_sam_model(
    image: numpy.ndarray,
    label_image: numpy.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Union[str, os.PathLike, NoneType] = None,
    **kwargs
) -> None:
Export SAM model to BioImage.IO model format.
The exported model can be uploaded to bioimage.io and be used in tools that support the BioImage.IO model format.
Arguments:
- image: The image for generating test data.
- label_image: The segmentation corresponding to `image`. It is used to derive prompt inputs for the model.
- model_type: The type of the SAM model.
- name: The name of the exported model.
- output_path: Where the exported model is saved.
- checkpoint_path: Optional checkpoint for loading the SAM model.
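
A minimal usage sketch. The input file names, the use of imageio for loading, and the chosen model type, name, and output path are illustrative assumptions, not part of this module:

import imageio.v3 as imageio
from micro_sam.bioimageio.model_export import export_sam_model

# Hypothetical test data: a 2D image and a corresponding instance segmentation.
# The label image is expected to contain the object ids 1 and 2, since these are
# used to derive the example prompts for the exported model.
image = imageio.imread("example_image.tif")
label_image = imageio.imread("example_labels.tif")

export_sam_model(
    image=image,
    label_image=label_image,
    model_type="vit_b",  # the SAM backbone to export
    name="sam-microscopy-example",
    output_path="./sam_microscopy_example.zip",
    checkpoint_path=None,  # None: fetch the released weights for this model type from the registry
    # Metadata such as description, authors, cite, and tags can be overridden via kwargs;
    # otherwise the DEFAULTS defined above are used.
)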