micro_sam.bioimageio.model_export
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import xarray
import numpy as np
import matplotlib.pyplot as plt

import torch

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.digest_spec import create_sample_for_model

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
from ..prompt_based_segmentation import _compute_logits_from_mask
from .predictor_adaptor import PredictorAdaptor


DEFAULTS = {
    "authors": [
        spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
        spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"),
    ],
    "description": "Finetuned Segment Anything Model for Microscopy",
    "cite": [
        spec.CiteEntry(text="Archit et al. Segment Anything for Microscopy", doi=spec.Doi("10.1101/2023.08.21.554208")),
    ],
    "tags": ["segment-anything", "instance-segmentation"],
}


def _create_test_inputs_and_outputs(
    image,
    labels,
    model_type,
    checkpoint_path,
    tmp_dir,
):
    # For now we just generate a single box prompt here, but we could also generate more input prompts.
    generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=2,
        dilation_strength=2,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels)
    masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2])  # type: ignore
    point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]])

    box_prompts = box_prompts.numpy()[None]
    point_prompts = point_prompts.numpy()[None]
    point_labels = point_labels.numpy()[None]

    # Generate logits from the masks of the two objects.
    mask_prompts = np.stack(
        [
            _compute_logits_from_mask(labels == 1),
            _compute_logits_from_mask(labels == 2),
        ]
    )[None]

    predictor = PredictorAdaptor(model_type=model_type)
    predictor.load_state_dict(torch.load(checkpoint_path))

    input_ = util._to_image(image).transpose(2, 0, 1)[None]
    image_path = os.path.join(tmp_dir, "input.npy")
    np.save(image_path, input_)

    masks, scores, embeddings = predictor(
        image=torch.from_numpy(input_),
        embeddings=None,
        box_prompts=torch.from_numpy(box_prompts),
        point_prompts=torch.from_numpy(point_prompts),
        point_labels=torch.from_numpy(point_labels),
        mask_prompts=torch.from_numpy(mask_prompts),
    )

    box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy")
    point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy")
    point_label_path = os.path.join(tmp_dir, "point_labels.npy")
    mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy")
    np.save(box_prompt_path, box_prompts.astype("int64"))
    np.save(point_prompt_path, point_prompts)
    np.save(point_label_path, point_labels)
    np.save(mask_prompt_path, mask_prompts)

    mask_path = os.path.join(tmp_dir, "mask.npy")
    score_path = os.path.join(tmp_dir, "scores.npy")
    embed_path = os.path.join(tmp_dir, "embeddings.npy")
    np.save(mask_path, masks.numpy())
    np.save(score_path, scores.numpy())
    np.save(embed_path, embeddings.numpy())

    inputs = {
        "image": image_path,
        "box_prompts": box_prompt_path,
        "point_prompts": point_prompt_path,
        "point_labels": point_label_path,
        "mask_prompts": mask_prompt_path,
    }
    outputs = {
        "mask": mask_path,
        "score": score_path,
        "embeddings": embed_path,
    }
    return inputs, outputs


def _write_documentation(doc, model_type, tmp_dir):
    tmp_doc_path = os.path.join(tmp_dir, "documentation.md")

    if doc is None:
        with open(tmp_doc_path, "w") as f:
            f.write("# Segment Anything for Microscopy\n")
            f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
            f.write("by training specialized models for microscopy data.\n")
        return tmp_doc_path

    elif os.path.exists(doc):
        return doc

    else:
        with open(tmp_doc_path, "w") as f:
            f.write(doc)
        return tmp_doc_path


def _get_checkpoint(model_type, checkpoint_path, tmp_dir):
    # If we don't have a checkpoint we get the corresponding model from the registry.
    if checkpoint_path is None:
        model_registry = util.models()
        checkpoint_path = model_registry.fetch(model_type)
        return checkpoint_path, None

    # Otherwise we have to load the checkpoint to see if it is the state dict of an encoder,
    # or the checkpoint for a custom SAM model.
    state, model_state = util._load_checkpoint(checkpoint_path)

    if "model_state" in state:  # This is a finetuning checkpoint -> we have to resave the state.
        new_checkpoint_path = os.path.join(tmp_dir, f"{model_type}.pt")
        torch.save(model_state, new_checkpoint_path)

        # We may also have an instance segmentation decoder in that case.
        # If we have it we also resave this one and return it.
        if "decoder_state" in state:
            decoder_path = os.path.join(tmp_dir, f"{model_type}_decoder.pt")
            decoder_state = state["decoder_state"]
            torch.save(decoder_state, decoder_path)
        else:
            decoder_path = None

        return new_checkpoint_path, decoder_path

    else:  # This is a SAM encoder state -> we don't have to resave.
        return checkpoint_path, None


def _write_dependencies(dependency_file, require_mobile_sam):
    content = """name: sam
channels:
  - pytorch
  - conda-forge
dependencies:
  - segment-anything"""
    if require_mobile_sam:
        content += """
  - pip:
    - git+https://github.com/ChaoningZhang/MobileSAM.git"""
    with open(dependency_file, "w") as f:
        f.write(content)


def _generate_covers(input_paths, result_paths, tmp_dir):
    image = np.load(input_paths["image"]).squeeze()
    prompts = np.load(input_paths["box_prompts"])
    mask = np.load(result_paths["mask"])

    # create the image overlay
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    elif image.shape[0] == 3:
        overlay = image.transpose((1, 2, 0))
    else:
        overlay = image
    overlay = _enhance_image(overlay.astype("float32"))

    # overlay the mask as outline
    overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2)

    # overlay the bounding box prompt
    prompt = prompts[0, 0][[1, 0, 3, 2]]
    prompt = np.array([prompt[:2], prompt[2:]])
    overlay = _overlay_box(overlay, prompt, outline_dilation=4)

    # write the cover image
    fig, ax = plt.subplots(1)
    ax.axis("off")
    ax.imshow(overlay.astype("uint8"))
    cover_path = os.path.join(tmp_dir, "cover.jpeg")
    plt.savefig(cover_path, bbox_inches="tight")
    plt.close()

    covers = [cover_path]
    return covers


def _check_model(model_description, input_paths, result_paths):
    # Load inputs.
    image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
    embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
    box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
    point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic"))
    point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
    mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))

    # Load outputs.
    mask = np.load(result_paths["mask"])

    with bioimageio.core.create_prediction_pipeline(model_description) as pp:

        # Check with all prompts. We only check the result for this setting,
        # because this was used to generate the test data.
        sample = create_sample_for_model(
            model=model_description,
            image=image,
            box_prompts=box_prompts,
            point_prompts=point_prompts,
            point_labels=point_labels,
            mask_prompts=mask_prompts,
            embeddings=embeddings,
        ).as_single_block()
        prediction = pp.predict_sample_block(sample)

        predicted_mask = prediction.blocks["masks"].data.data
        assert predicted_mask.shape == mask.shape
        assert np.allclose(mask, predicted_mask)

        # Run the checks with partial prompts.
        prompt_kwargs = [
            # With boxes.
            {"box_prompts": box_prompts},
            # With point prompts.
            {"point_prompts": point_prompts, "point_labels": point_labels},
            # With masks.
            {"mask_prompts": mask_prompts},
            # With boxes and points.
            {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
            # With boxes and masks.
            {"box_prompts": box_prompts, "mask_prompts": mask_prompts},
            # With points and masks.
            {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels},
        ]

        for kwargs in prompt_kwargs:
            sample = create_sample_for_model(
                model=model_description, image=image, embeddings=embeddings, **kwargs
            ).as_single_block()
            prediction = pp.predict_sample_block(sample)
            predicted_mask = prediction.blocks["masks"].data.data
            assert predicted_mask.shape == mask.shape


def export_sam_model(
    image: np.ndarray,
    label_image: np.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    **kwargs
) -> None:
    """Export SAM model to BioImage.IO model format.

    The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
    be used in tools that support the BioImage.IO model format.

    Args:
        image: The image for generating test data.
        label_image: The segmentation corresponding to `image`.
            It is used to derive prompt inputs for the model.
        model_type: The type of the SAM model.
        name: The name of the exported model.
        output_path: Where the exported model is saved.
        checkpoint_path: Optional checkpoint for loading the SAM model.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path, decoder_path = _get_checkpoint(model_type, checkpoint_path, tmp_dir)
        input_paths, result_paths = _create_test_inputs_and_outputs(
            image, label_image, model_type, checkpoint_path, tmp_dir,
        )
        input_descriptions = [
            # First input: the image data.
            spec.InputTensorDescr(
                id=spec.TensorId("image"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: to support 1 and 3 channels we can add another preprocessing.
                    # Best solution: Have a pre-processing for this! (1C -> RGB)
                    spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE),
                ],
                test_tensor=spec.FileDescr(source=input_paths["image"]),
                data=spec.IntervalOrRatioDataDescr(type="uint8")
            ),

            # Second input: the box prompts (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("box_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=spec.ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["box_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Third input: the point prompt coordinates (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("point_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=spec.ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=spec.ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fourth input: the point prompt labels (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("point_labels"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=spec.ARBITRARY_SIZE
                    ),
                    spec.IndexInputAxis(
                        id=spec.AxisId("point"),
                        size=spec.ARBITRARY_SIZE
                    ),
                ],
                test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
                data=spec.IntervalOrRatioDataDescr(type="int64")
            ),

            # Fifth input: the mask prompts (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("mask_prompts"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    spec.IndexInputAxis(
                        id=spec.AxisId("object"),
                        size=spec.ARBITRARY_SIZE
                    ),
                    spec.ChannelAxis(channel_names=["channel"]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=256),
                ],
                test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

            # Sixth input: the image embeddings (optional)
            spec.InputTensorDescr(
                id=spec.TensorId("embeddings"),
                optional=True,
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we currently have to specify all the channel names
                    # (It would be nice to also support size)
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceInputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceInputAxis(id=spec.AxisId("x"), size=64),
                ],
                test_tensor=spec.FileDescr(source=result_paths["embeddings"]),
                data=spec.IntervalOrRatioDataDescr(type="float32")
            ),

        ]

        output_descriptions = [
            # First output: The mask predictions.
            spec.OutputTensorDescr(
                id=spec.TensorId("masks"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be a 3 once we use multi-masking
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("y"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"),
                        )
                    ),
                    spec.SpaceOutputAxis(
                        id=spec.AxisId("x"),
                        size=spec.SizeReference(
                            tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"),
                        )
                    )
                ],
                data=spec.IntervalOrRatioDataDescr(type="uint8"),
                test_tensor=spec.FileDescr(source=result_paths["mask"])
            ),

            # The score predictions
            spec.OutputTensorDescr(
                id=spec.TensorId("scores"),
                axes=[
                    spec.BatchAxis(size=1),
                    # NOTE: we use the data dependent size here to avoid dependency on optional inputs
                    spec.IndexOutputAxis(
                        id=spec.AxisId("object"), size=spec.DataDependentSize(),
                    ),
                    # NOTE: this could be a 3 once we use multi-masking
                    spec.ChannelAxis(channel_names=[spec.Identifier("mask")]),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["score"])
            ),

            # The image embeddings
            spec.OutputTensorDescr(
                id=spec.TensorId("embeddings"),
                axes=[
                    spec.BatchAxis(size=1),
                    spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
                    spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
                    spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
                ],
                data=spec.IntervalOrRatioDataDescr(type="float32"),
                test_tensor=spec.FileDescr(source=result_paths["embeddings"])
            )
        ]

        architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(architecture_path),
            callable="PredictorAdaptor",
            kwargs={"model_type": model_type}
        )

        dependency_file = os.path.join(tmp_dir, "environment.yaml")
        _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t"))

        weight_descriptions = spec.WeightsDescr(
            pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
                source=Path(checkpoint_path),
                architecture=architecture,
                pytorch_version=spec.Version(torch.__version__),
                dependencies=spec.EnvironmentFileDescr(source=dependency_file),
            )
        )

        doc_path = _write_documentation(kwargs.get("documentation", None), model_type, tmp_dir)

        covers = kwargs.get("covers", None)
        if covers is None:
            covers = _generate_covers(input_paths, result_paths, tmp_dir)
        else:
            assert all(os.path.exists(cov) for cov in covers)

        # the uploader information is only added if explicitly passed
        extra_kwargs = {}
        if "id" in kwargs:
            extra_kwargs["id"] = kwargs["id"]
        if "id_emoji" in kwargs:
            extra_kwargs["id_emoji"] = kwargs["id_emoji"]
        if "uploader" in kwargs:
            extra_kwargs["uploader"] = kwargs["uploader"]

        if decoder_path is not None:
            extra_kwargs["attachments"] = [spec.FileDescr(source=decoder_path)]

        model_description = spec.ModelDescr(
            name=name,
            inputs=input_descriptions,
            outputs=output_descriptions,
            weights=weight_descriptions,
            description=kwargs.get("description", DEFAULTS["description"]),
            authors=kwargs.get("authors", DEFAULTS["authors"]),
            cite=kwargs.get("cite", DEFAULTS["cite"]),
            license=spec.LicenseId("CC-BY-4.0"),
            documentation=Path(doc_path),
            git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"),
            tags=kwargs.get("tags", DEFAULTS["tags"]),
            covers=covers,
            **extra_kwargs,
            # TODO write specific settings in the config
            # dict with yaml values, key must be a str
            # micro_sam: ...
            # config=
        )

        _check_model(model_description, input_paths, result_paths)

        save_bioimageio_package(model_description, output_path=output_path)
DEFAULTS =
{'authors': [Author(affiliation='University Goettingen', email=None, orcid=None, name='Anwai Archit', github_user='anwai98'), Author(affiliation='University Goettingen', email=None, orcid=None, name='Constantin Pape', github_user='constantinpape')], 'description': 'Finetuned Segment Anything Model for Microscopy', 'cite': [CiteEntry(text='Archit et al. Segment Anything for Microscopy', doi='10.1101/2023.08.21.554208', url=None)], 'tags': ['segment-anything', 'instance-segmentation']}
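These defaults only act as fallbacks: export_sam_model reads the description, authors, citation, and tags via kwargs.get(..., DEFAULTS[...]), so each of them can be replaced by passing the corresponding keyword argument. A minimal sketch of building such overrides follows; the names and values are placeholders, not part of micro_sam.

import bioimageio.spec.model.v0_5 as spec

# Placeholder metadata that replaces the DEFAULTS entries when passed to export_sam_model
# (see the usage sketch at the end of this page). "cite" and "description" can be overridden likewise.
metadata_overrides = {
    "description": "SAM finetuned for my microscopy data",
    "authors": [spec.Author(name="Jane Doe", affiliation="Example Institute", github_user="janedoe")],
    "tags": ["segment-anything", "instance-segmentation", "nuclei"],
}
# export_sam_model(..., **metadata_overrides)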
def export_sam_model(
    image: numpy.ndarray,
    label_image: numpy.ndarray,
    model_type: str,
    name: str,
    output_path: Union[str, os.PathLike],
    checkpoint_path: Union[str, os.PathLike, NoneType] = None,
    **kwargs
) -> None:
Export SAM model to BioImage.IO model format.
The exported model can be uploaded to bioimage.io and be used in tools that support the BioImage.IO model format.
Arguments:
- image: The image for generating test data.
- label_image: The segmentation corresponding to `image`. It is used to derive prompt inputs for the model.
- model_type: The type of the SAM model.
- name: The name of the exported model.
- output_path: Where the exported model is saved.
- checkpoint_path: Optional checkpoint for loading the SAM model.
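A minimal usage sketch. The image and label data below are synthetic placeholders; in practice you pass a real microscopy image and its instance segmentation. Note that the test prompts are derived from the objects with label ids 1 and 2, so the label image must contain at least these two ids.

import numpy as np
from micro_sam.bioimageio.model_export import export_sam_model

# Synthetic placeholder data: a 2D image with two labeled objects (ids 1 and 2).
image = np.random.randint(0, 256, size=(512, 512), dtype="uint8")
label_image = np.zeros((512, 512), dtype="int64")
label_image[100:200, 100:200] = 1
label_image[300:400, 300:400] = 2

# Export the default vit_b model; pass checkpoint_path to export a finetuned model instead.
export_sam_model(
    image=image,
    label_image=label_image,
    model_type="vit_b",
    name="sam-vit-b-example",
    output_path="./sam_vit_b_example.zip",
    # checkpoint_path="./checkpoints/finetuned_sam/best.pt",  # hypothetical finetuning checkpoint
)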