micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
1"""Automated instance segmentation functionality. 2The classes implemented here extend the automatic instance segmentation from Segment Anything: 3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 4""" 5 6import os 7import shutil 8import tempfile 9import warnings 10from abc import ABC 11from contextlib import contextmanager 12from copy import deepcopy 13from collections import OrderedDict 14from typing import Any, Dict, Literal, List, Optional, Tuple, Union 15 16import numpy as np 17import zarr 18from skimage.measure import regionprops 19from skimage.segmentation import find_boundaries 20 21from bioimage_cpp import filters as filter_impl 22 23import torch 24from torchvision.ops.boxes import batched_nms, box_area 25 26from torch_em.model import UNETR 27from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 28 29import elf.parallel as parallel_impl 30from elf.parallel.filters import apply_filter 31from elf.wrapper.base import MultiTransformationWrapper 32from elf.wrapper.generic import ThresholdWrapper 33 34from bioimage_cpp.utils import Blocking 35 36import segment_anything.utils.amg as amg_utils 37from segment_anything.predictor import SamPredictor 38 39from . import util 40from .inference import batched_inference, batched_tiled_inference 41from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 42 43# We may change this to 'apg' in version 1.8. 44DEFAULT_SEGMENTATION_MODE_WITH_DECODER = "ais" 45 46# 47# Utility Functionality 48# 49 50 51class _FakeInput: 52 def __init__(self, shape): 53 self.shape = shape 54 55 def __getitem__(self, index): 56 block_shape = tuple(ind.stop - ind.start for ind in index) 57 return np.zeros(block_shape, dtype="float32") 58 59 60# 61# Classes for automatic instance segmentation 62# 63 64 65class AMGBase(ABC): 66 """Base class for the automatic mask generators. 
67 """ 68 def __init__(self): 69 # the state that has to be computed by the 'initialize' method of the child classes 70 self._is_initialized = False 71 self._crop_list = None 72 self._crop_boxes = None 73 self._original_size = None 74 75 @property 76 def is_initialized(self): 77 """Whether the mask generator has already been initialized. 78 """ 79 return self._is_initialized 80 81 @property 82 def crop_list(self): 83 """The list of mask data after initialization. 84 """ 85 return self._crop_list 86 87 @property 88 def crop_boxes(self): 89 """The list of crop boxes. 90 """ 91 return self._crop_boxes 92 93 @property 94 def original_size(self): 95 """The original image size. 96 """ 97 return self._original_size 98 99 def _postprocess_batch( 100 self, 101 data, 102 crop_box, 103 original_size, 104 pred_iou_thresh, 105 stability_score_thresh, 106 box_nms_thresh, 107 ): 108 orig_h, orig_w = original_size 109 110 # filter by predicted IoU 111 if pred_iou_thresh > 0.0: 112 keep_mask = data["iou_preds"] > pred_iou_thresh 113 data.filter(keep_mask) 114 115 # filter by stability score 116 if stability_score_thresh > 0.0: 117 keep_mask = data["stability_score"] >= stability_score_thresh 118 data.filter(keep_mask) 119 120 # filter boxes that touch crop boundaries 121 keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 122 if not torch.all(keep_mask): 123 data.filter(keep_mask) 124 125 # remove duplicates within this crop. 
126 keep_by_nms = batched_nms( 127 data["boxes"].float(), 128 data["iou_preds"], 129 torch.zeros_like(data["boxes"][:, 0]), # categories 130 iou_threshold=box_nms_thresh, 131 ) 132 data.filter(keep_by_nms) 133 134 # return to the original image frame 135 data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box) 136 data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 137 # the data from embedding based segmentation doesn't have the points 138 # so we skip if the corresponding key can't be found 139 try: 140 data["points"] = amg_utils.uncrop_points(data["points"], crop_box) 141 except KeyError: 142 pass 143 144 return data 145 146 def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): 147 148 if len(mask_data["rles"]) == 0: 149 return mask_data 150 151 # filter small disconnected regions and holes 152 new_masks = [] 153 scores = [] 154 for rle in mask_data["rles"]: 155 mask = amg_utils.rle_to_mask(rle) 156 157 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes") 158 unchanged = not changed 159 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands") 160 unchanged = unchanged and not changed 161 162 new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0)) 163 # give score=0 to changed masks and score=1 to unchanged masks 164 # so NMS will prefer ones that didn't need postprocessing 165 scores.append(float(unchanged)) 166 167 # recalculate boxes and remove any new duplicates 168 masks = torch.cat(new_masks, dim=0) 169 boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. 
170 keep_by_nms = batched_nms( 171 boxes.float(), 172 torch.as_tensor(scores, dtype=torch.float), 173 torch.zeros_like(boxes[:, 0]), # categories 174 iou_threshold=nms_thresh, 175 ) 176 177 # only recalculate RLEs for masks that have changed 178 for i_mask in keep_by_nms: 179 if scores[i_mask] == 0.0: 180 mask_torch = masks[i_mask].unsqueeze(0) 181 # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] 182 mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 183 mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 184 mask_data.filter(keep_by_nms) 185 186 return mask_data 187 188 def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode): 189 # filter small disconnected regions and holes in masks 190 if min_mask_region_area > 0: 191 mask_data = self._postprocess_small_regions( 192 mask_data, 193 min_mask_region_area, 194 max(box_nms_thresh, crop_nms_thresh), 195 ) 196 197 # encode masks 198 if output_mode == "coco_rle": 199 mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]] 200 # We require the binary mask output for creating the instance-seg output. 
201 elif output_mode in ("binary_mask", "instance_segmentation"): 202 mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]] 203 elif output_mode == "rle": 204 mask_data["segmentations"] = mask_data["rles"] 205 else: 206 raise ValueError(f"Invalid output mode {output_mode}.") 207 208 # write mask records 209 curr_anns = [] 210 for idx in range(len(mask_data["segmentations"])): 211 ann = { 212 "segmentation": mask_data["segmentations"][idx], 213 "area": amg_utils.area_from_rle(mask_data["rles"][idx]), 214 "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 215 "predicted_iou": mask_data["iou_preds"][idx].item(), 216 "stability_score": mask_data["stability_score"][idx].item(), 217 "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 218 } 219 # the data from embedding based segmentation doesn't have the points 220 # so we skip if the corresponding key can't be found 221 try: 222 ann["point_coords"] = [mask_data["points"][idx].tolist()] 223 except KeyError: 224 pass 225 curr_anns.append(ann) 226 227 return curr_anns 228 229 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): 230 orig_h, orig_w = original_size 231 232 # serialize predictions and store in MaskData 233 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1)) 234 if points is not None: 235 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float) 236 237 del masks 238 239 # calculate the stability scores 240 data["stability_score"] = amg_utils.calculate_stability_score( 241 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset 242 ) 243 244 # threshold masks and calculate boxes 245 data["masks"] = data["masks"] > self._predictor.model.mask_threshold 246 data["masks"] = data["masks"].type(torch.bool) 247 data["boxes"] = batched_mask_to_box(data["masks"]) 248 249 # compress to RLE 250 data["masks"] = 
amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 251 # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) 252 data["rles"] = mask_to_rle_pytorch(data["masks"]) 253 del data["masks"] 254 255 return data 256 257 def get_state(self) -> Dict[str, Any]: 258 """Get the initialized state of the mask generator. 259 260 Returns: 261 State of the mask generator. 262 """ 263 if not self.is_initialized: 264 raise RuntimeError("The state has not been computed yet. Call initialize first.") 265 266 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size} 267 268 def set_state(self, state: Dict[str, Any]) -> None: 269 """Set the state of the mask generator. 270 271 Args: 272 state: The state of the mask generator, e.g. from serialized state. 273 """ 274 self._crop_list = state["crop_list"] 275 self._crop_boxes = state["crop_boxes"] 276 self._original_size = state["original_size"] 277 self._is_initialized = True 278 279 def clear_state(self): 280 """Clear the state of the mask generator. 281 """ 282 self._crop_list = None 283 self._crop_boxes = None 284 self._original_size = None 285 self._is_initialized = False 286 287 288class AutomaticMaskGenerator(AMGBase): 289 """Generates an instance segmentation without prompts, using a point grid. 290 291 This class implements the same logic as 292 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 293 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 294 to filter these masks to enable grid search and interactively changing the post-processing. 295 296 Use this class as follows: 297 ```python 298 amg = AutomaticMaskGenerator(predictor) 299 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 300 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. 
This is fast and enables testing parameters 301 ``` 302 303 Args: 304 predictor: The segment anything predictor. 305 points_per_side: The number of points to be sampled along one side of the image. 306 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 307 points_per_batch: The number of points run simultaneously by the model. 308 Higher numbers may be faster but use more GPU memory. 309 By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons). 310 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 311 By default, set to '0'. 312 crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'. 313 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 314 By default, set to '1'. 315 point_grids: A list over explicit grids of points used for sampling masks. 316 Normalized to [0, 1] with respect to the image coordinate system. 317 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 318 By default, set to '1.0'. 
319 """ 320 def __init__( 321 self, 322 predictor: SamPredictor, 323 points_per_side: Optional[int] = 32, 324 points_per_batch: Optional[int] = None, 325 crop_n_layers: int = 0, 326 crop_overlap_ratio: float = 512 / 1500, 327 crop_n_points_downscale_factor: int = 1, 328 point_grids: Optional[List[np.ndarray]] = None, 329 stability_score_offset: float = 1.0, 330 ): 331 super().__init__() 332 333 if points_per_side is not None: 334 self.point_grids = amg_utils.build_all_layer_point_grids( 335 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 336 ) 337 elif point_grids is not None: 338 self.point_grids = point_grids 339 else: 340 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 341 342 self._predictor = predictor 343 self._points_per_side = points_per_side 344 345 # we set the points per batch to 16 for mps for performance reasons 346 # and otherwise keep them at the default of 64 347 if points_per_batch is None: 348 points_per_batch = 16 if str(predictor.device) == "mps" else 64 349 self._points_per_batch = points_per_batch 350 351 self._crop_n_layers = crop_n_layers 352 self._crop_overlap_ratio = crop_overlap_ratio 353 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 354 self._stability_score_offset = stability_score_offset 355 356 def _process_batch(self, points, im_size, crop_box, original_size): 357 # run model on this batch 358 transformed_points = self._predictor.transform.apply_coords(points, im_size) 359 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 360 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 361 masks, iou_preds, _ = self._predictor.predict_torch( 362 point_coords=in_points[:, None, :], 363 point_labels=in_labels[:, None], 364 multimask_output=True, 365 return_logits=True, 366 ) 367 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 368 del masks 369 return 
data 370 371 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 372 # Crop the image and calculate embeddings. 373 x0, y0, x1, y1 = crop_box 374 cropped_im = image[y0:y1, x0:x1, :] 375 cropped_im_size = cropped_im.shape[:2] 376 377 if not precomputed_embeddings: 378 self._predictor.set_image(cropped_im) 379 380 # Get the points for this crop. 381 points_scale = np.array(cropped_im_size)[None, ::-1] 382 points_for_image = self.point_grids[crop_layer_idx] * points_scale 383 384 # Generate masks for this crop in batches. 385 data = amg_utils.MaskData() 386 n_batches = len(points_for_image) // self._points_per_batch +\ 387 int(len(points_for_image) % self._points_per_batch != 0) 388 if pbar_init is not None: 389 pbar_init(n_batches, "Predict masks for point grid prompts") 390 391 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 392 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 393 data.cat(batch_data) 394 del batch_data 395 if pbar_update is not None: 396 pbar_update(1) 397 398 if not precomputed_embeddings: 399 self._predictor.reset_image() 400 401 return data 402 403 @torch.no_grad() 404 def initialize( 405 self, 406 image: np.ndarray, 407 image_embeddings: Optional[util.ImageEmbeddings] = None, 408 i: Optional[int] = None, 409 verbose: bool = False, 410 pbar_init: Optional[callable] = None, 411 pbar_update: Optional[callable] = None, 412 ) -> None: 413 """Initialize image embeddings and masks for an image. 414 415 Args: 416 image: The input image, volume or timeseries. 417 image_embeddings: Optional precomputed image embeddings. 418 See `util.precompute_image_embeddings` for details. 419 i: Index for the image data. Required if `image` has three spatial dimensions 420 or a time dimension and two spatial dimensions. 421 verbose: Whether to print computation progress. By default, set to 'False'. 
422 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 423 Can be used together with pbar_update to handle napari progress bar in other thread. 424 To enable using this function within a threadworker. 425 pbar_update: Callback to update an external progress bar. 426 """ 427 original_size = image.shape[:2] 428 self._original_size = original_size 429 430 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 431 original_size, self._crop_n_layers, self._crop_overlap_ratio 432 ) 433 434 # We can set fixed image embeddings if we only have a single crop box (the default setting). 435 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 436 if len(crop_boxes) == 1: 437 if image_embeddings is None: 438 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 439 util.set_precomputed(self._predictor, image_embeddings, i=i) 440 precomputed_embeddings = True 441 else: 442 precomputed_embeddings = False 443 444 # we need to cast to the image representation that is compatible with SAM 445 image = util._to_image(image) 446 447 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 448 449 crop_list = [] 450 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 451 crop_data = self._process_crop( 452 image, crop_box, layer_idx, 453 precomputed_embeddings=precomputed_embeddings, 454 pbar_init=pbar_init, pbar_update=pbar_update, 455 ) 456 crop_list.append(crop_data) 457 pbar_close() 458 459 self._is_initialized = True 460 self._crop_list = crop_list 461 self._crop_boxes = crop_boxes 462 463 @torch.no_grad() 464 def generate( 465 self, 466 pred_iou_thresh: float = 0.88, 467 stability_score_thresh: float = 0.95, 468 box_nms_thresh: float = 0.7, 469 crop_nms_thresh: float = 0.7, 470 min_mask_region_area: int = 0, 471 output_mode: str = "instance_segmentation", 472 with_background: bool = True, 473 ) -> Union[List[Dict[str, Any]], np.ndarray]: 
474 """Generate instance segmentation for the currently initialized image. 475 476 Args: 477 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 478 By default, set to '0.88'. 479 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 480 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 481 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 482 By default, set to '0.7'. 483 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 484 By default, set to '0.7'. 485 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 486 output_mode: The form masks are returned in. Possible values are: 487 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 488 - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format. 489 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 490 - 'rle': Return a list of dictionaries with run-length encoded masks. 491 By default, set to 'instance_segmentation'. 492 with_background: Whether to remove the largest object, which often covers the background. 493 494 Returns: 495 The segmentation masks. 496 """ 497 if not self.is_initialized: 498 raise RuntimeError("AutomaticMaskGenerator has not been initialized. 
Call initialize first.") 499 500 data = amg_utils.MaskData() 501 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 502 crop_data = self._postprocess_batch( 503 data=deepcopy(data_), 504 crop_box=crop_box, original_size=self.original_size, 505 pred_iou_thresh=pred_iou_thresh, 506 stability_score_thresh=stability_score_thresh, 507 box_nms_thresh=box_nms_thresh 508 ) 509 data.cat(crop_data) 510 511 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 512 # Prefer masks from smaller crops 513 scores = 1 / box_area(data["crop_boxes"]) 514 scores = scores.to(data["boxes"].device) 515 keep_by_nms = batched_nms( 516 data["boxes"].float(), 517 scores, 518 torch.zeros_like(data["boxes"][:, 0]), # categories 519 iou_threshold=crop_nms_thresh, 520 ) 521 data.filter(keep_by_nms) 522 523 data.to_numpy() 524 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 525 if output_mode == "instance_segmentation": 526 shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size 527 masks = util.mask_data_to_segmentation( 528 masks, shape=shape, with_background=with_background, merge_exclusively=False 529 ) 530 return masks 531 532 533# Helper function for tiled embedding computation and checking consistent state. 534def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size, mask, i): 535 if image_embeddings is None: 536 if tile_shape is None or halo is None: 537 raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.") 538 image_embeddings = util.precompute_image_embeddings( 539 predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size, mask=mask, 540 ) 541 542 # Use tile shape and halo from the precomputed embeddings if not given. 543 # Otherwise check that they are consistent. 
544 feats = image_embeddings["features"] 545 tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"]) 546 if tile_shape is None: 547 tile_shape = tile_shape_ 548 elif tile_shape != tile_shape_: 549 raise ValueError( 550 f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}." 551 ) 552 if halo is None: 553 halo = halo_ 554 elif halo != halo_: 555 raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.") 556 557 tiles_in_mask = feats.attrs.get("tiles_in_mask", None) 558 if tiles_in_mask is not None and i is not None: 559 tiles_in_mask = tiles_in_mask[str(i)] 560 561 return image_embeddings, tile_shape, halo, tiles_in_mask 562 563 564class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 565 """Generates an instance segmentation without prompts, using a point grid. 566 567 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 568 569 Args: 570 predictor: The Segment Anything predictor. 571 points_per_side: The number of points to be sampled along one side of the image. 572 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 573 points_per_batch: The number of points run simultaneously by the model. 574 Higher numbers may be faster but use more GPU memory. By default, set to '64'. 575 point_grids: A list over explicit grids of points used for sampling masks. 576 Normalized to [0, 1] with respect to the image coordinate system. 577 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 578 By default, set to '1.0'. 579 """ 580 581 # We only expose the arguments that make sense for the tiled mask generator. 582 # Anything related to crops doesn't make sense, because we re-use that functionality 583 # for tiling, so these parameters wouldn't have any effect. 
584 def __init__( 585 self, 586 predictor: SamPredictor, 587 points_per_side: Optional[int] = 32, 588 points_per_batch: int = 64, 589 point_grids: Optional[List[np.ndarray]] = None, 590 stability_score_offset: float = 1.0, 591 ) -> None: 592 super().__init__( 593 predictor=predictor, 594 points_per_side=points_per_side, 595 points_per_batch=points_per_batch, 596 point_grids=point_grids, 597 stability_score_offset=stability_score_offset, 598 ) 599 600 @torch.no_grad() 601 def initialize( 602 self, 603 image: np.ndarray, 604 image_embeddings: Optional[util.ImageEmbeddings] = None, 605 i: Optional[int] = None, 606 tile_shape: Optional[Tuple[int, int]] = None, 607 halo: Optional[Tuple[int, int]] = None, 608 verbose: bool = False, 609 pbar_init: Optional[callable] = None, 610 pbar_update: Optional[callable] = None, 611 batch_size: int = 1, 612 mask: Optional[np.typing.ArrayLike] = None, 613 ) -> None: 614 """Initialize image embeddings and masks for an image. 615 616 Args: 617 image: The input image, volume or timeseries. 618 image_embeddings: Optional precomputed image embeddings. 619 See `util.precompute_image_embeddings` for details. 620 i: Index for the image data. Required if `image` has three spatial dimensions 621 or a time dimension and two spatial dimensions. 622 tile_shape: The tile shape for embedding prediction. 623 halo: The overlap of between tiles. 624 verbose: Whether to print computation progress. By default, set to 'False'. 625 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 626 Can be used together with pbar_update to handle napari progress bar in other thread. 627 To enable using this function within a threadworker. 628 pbar_update: Callback to update an external progress bar. 629 batch_size: The batch size for image embedding prediction. By default, set to '1'. 630 mask: An optional mask to define areas that are ignored in the segmentation. 
631 """ 632 original_size = image.shape[:2] 633 self._original_size = original_size 634 635 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 636 self._predictor, image, image_embeddings, tile_shape, halo, 637 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 638 ) 639 640 tiling = Blocking([0, 0], original_size, tile_shape) 641 if tiles_in_mask is None: 642 n_tiles = tiling.number_of_blocks 643 tile_ids = range(n_tiles) 644 else: 645 n_tiles = len(tiles_in_mask) 646 tile_ids = tiles_in_mask 647 648 # The crop box is always the full local tile. 649 tiles = [tiling.get_block_with_halo(tile_id, list(halo)).outer_block for tile_id in tile_ids] 650 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 651 652 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 653 pbar_init(n_tiles, "Compute masks for tile") 654 655 # We need to cast to the image representation that is compatible with SAM. 
656 image = util._to_image(image) 657 658 mask_data = [] 659 for idx, tile_id in enumerate(tile_ids): 660 # set the pre-computed embeddings for this tile 661 features = image_embeddings["features"][str(tile_id)] 662 tile_embeddings = { 663 "features": features, 664 "input_size": features.attrs["input_size"], 665 "original_size": features.attrs["original_size"], 666 } 667 util.set_precomputed(self._predictor, tile_embeddings, i) 668 669 # Compute the mask data for this tile and append it 670 this_mask_data = self._process_crop( 671 image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True 672 ) 673 mask_data.append(this_mask_data) 674 pbar_update(1) 675 pbar_close() 676 677 # set the initialized data 678 self._is_initialized = True 679 self._crop_list = mask_data 680 self._crop_boxes = crop_boxes 681 682 683# 684# Instance segmentation functionality based on fine-tuned decoder 685# 686 687 688class DecoderAdapter(torch.nn.Module): 689 """Adapter to contain the UNETR decoder in a single module. 690 691 To apply the decoder on top of pre-computed embeddings for the segmentation functionality. 
692 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 693 """ 694 def __init__(self, unetr: torch.nn.Module): 695 super().__init__() 696 697 self.base = unetr.base 698 self.out_conv = unetr.out_conv 699 self.deconv_out = unetr.deconv_out 700 self.decoder_head = unetr.decoder_head 701 self.final_activation = unetr.final_activation 702 self.postprocess_masks = unetr.postprocess_masks 703 704 self.decoder = unetr.decoder 705 self.deconv1 = unetr.deconv1 706 self.deconv2 = unetr.deconv2 707 self.deconv3 = unetr.deconv3 708 self.deconv4 = unetr.deconv4 709 710 def _forward_impl(self, input_): 711 z12 = input_ 712 713 z9 = self.deconv1(z12) 714 z6 = self.deconv2(z9) 715 z3 = self.deconv3(z6) 716 z0 = self.deconv4(z3) 717 718 updated_from_encoder = [z9, z6, z3] 719 720 x = self.base(z12) 721 x = self.decoder(x, encoder_inputs=updated_from_encoder) 722 x = self.deconv_out(x) 723 724 x = torch.cat([x, z0], dim=1) 725 x = self.decoder_head(x) 726 727 x = self.out_conv(x) 728 if self.final_activation is not None: 729 x = self.final_activation(x) 730 return x 731 732 def forward(self, input_, input_shape, original_shape): 733 x = self._forward_impl(input_) 734 x = self.postprocess_masks(x, input_shape, original_shape) 735 return x 736 737 738def get_unetr( 739 image_encoder: torch.nn.Module, 740 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 741 device: Optional[Union[str, torch.device]] = None, 742 out_channels: int = 3, 743 flexible_load_checkpoint: bool = False, 744 final_activation: Optional[str] = "Sigmoid", 745) -> torch.nn.Module: 746 """Get UNETR model for automatic instance segmentation. 747 748 Args: 749 image_encoder: The image encoder of the SAM model. 750 This is used as encoder by the UNETR too. 751 decoder_state: Optional decoder state to initialize the weights of the UNETR decoder. 752 device: The device. By default, automatically chooses the best available device. 
753 out_channels: The number of output channels. By default, set to '3'. 754 flexible_load_checkpoint: Whether to allow reinitialization of parameters 755 which could not be found in the provided decoder state. By default, set to 'False'. 756 final_activation: The activation applied to the network output. By default uses a Sigmoid activation. 757 758 Returns: 759 The UNETR model. 760 """ 761 device = util.get_device(device) 762 763 if decoder_state is None: 764 use_conv_transpose = False # By default, we use interpolation for upsampling. 765 else: 766 # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions. 767 # NOTE: Explanation to the logic below - 768 # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers" 769 # submodules. This naming convention indicates that transposed convolutions are used, 770 # wrapped inside a custom block module. 771 # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling. 772 use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers")) 773 774 unetr = UNETR( 775 backbone="sam", 776 encoder=image_encoder, 777 out_channels=out_channels, 778 use_sam_stats=True, 779 final_activation=final_activation, 780 use_skip_connection=False, 781 resize_input=True, 782 use_conv_transpose=use_conv_transpose, 783 ) 784 785 if decoder_state is not None: 786 unetr_state_dict = unetr.state_dict() 787 for k, v in unetr_state_dict.items(): 788 if not k.startswith("encoder"): 789 # Whether allow reinitalization of params, if not found or mismatched. 790 if flexible_load_checkpoint: 791 if k in decoder_state: # First check whether the key is available in the provided decoder state. 792 if v.shape != decoder_state[k].shape: # Then check if the sizes mismatch. 793 warnings.warn(f"Shape of '{k}' did not match. 
Hence, we reinitialize it.") 794 unetr_state_dict[k] = v 795 else: 796 unetr_state_dict[k] = decoder_state[k] 797 else: # Otherwise, allow it to initialize it. 798 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 799 unetr_state_dict[k] = v 800 801 else: # Be strict on finding the parameter in the decoder state. 802 if k not in decoder_state: 803 raise RuntimeError(f"The parameters for '{k}' could not be found or has a size mismatch.") 804 unetr_state_dict[k] = decoder_state[k] 805 806 unetr.load_state_dict(unetr_state_dict) 807 808 unetr.to(device) 809 return unetr 810 811 812def get_decoder( 813 image_encoder: torch.nn.Module, 814 decoder_state: OrderedDict[str, torch.Tensor], 815 device: Optional[Union[str, torch.device]] = None, 816) -> DecoderAdapter: 817 """Get decoder to predict outputs for automatic instance segmentation 818 819 Args: 820 image_encoder: The image encoder of the SAM model. 821 decoder_state: State to initialize the weights of the UNETR decoder. 822 device: The device. By default, automatically chooses the best available device. 823 824 Returns: 825 The decoder for instance segmentation. 826 """ 827 unetr = get_unetr(image_encoder, decoder_state, device) 828 return DecoderAdapter(unetr) 829 830 831def get_predictor_and_decoder( 832 model_type: str, 833 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 834 device: Optional[Union[str, torch.device]] = None, 835 peft_kwargs: Optional[Dict] = None, 836) -> Tuple[SamPredictor, DecoderAdapter]: 837 """Load the SAM model (predictor) and instance segmentation decoder. 838 839 This requires a checkpoint that contains the state for both predictor 840 and decoder. 841 842 Args: 843 model_type: The type of the image encoder used in the SAM model. 844 checkpoint_path: Path to the checkpoint from which to load the data. 845 device: The device. By default, automatically chooses the best available device. 
846 peft_kwargs: Keyword arguments for the PEFT wrapper class. 847 848 Returns: 849 The SAM predictor. 850 The decoder for instance segmentation. 851 """ 852 device = util.get_device(device) 853 predictor, state = util.get_sam_model( 854 model_type=model_type, 855 checkpoint_path=checkpoint_path, 856 device=device, 857 return_state=True, 858 peft_kwargs=peft_kwargs, 859 ) 860 861 if "decoder_state" not in state: 862 raise ValueError( 863 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 864 ) 865 866 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 867 return predictor, decoder 868 869 870@contextmanager 871def _array_or_zarr(shape, dtype, chunks, use_zarr=False): 872 if not use_zarr: 873 yield np.zeros(shape, dtype=dtype) 874 return 875 876 tmpdir = tempfile.mkdtemp(prefix="tmp-zarr-") 877 try: 878 store_path = os.path.join(tmpdir, "tmp.zarr") 879 root = zarr.open_group(store_path, mode="w") 880 arr = root.create_dataset(name="data", shape=shape, dtype=dtype, chunks=chunks) 881 yield arr 882 883 finally: 884 shutil.rmtree(tmpdir, ignore_errors=True) 885 886 887def _watershed_from_center_and_boundary_distances_parallel( 888 center_distances, 889 boundary_distances, 890 foreground_map, 891 center_distance_threshold, 892 boundary_distance_threshold, 893 foreground_threshold, 894 distance_smoothing, 895 min_size, 896 tile_shape, 897 halo, 898 n_threads, 899 verbose=False, 900 optimize_memory=False, 901 segmentation=None, 902): 903 center_distances = apply_filter( 904 center_distances, "gaussianSmoothing", sigma=distance_smoothing, 905 block_shape=tile_shape, n_threads=n_threads 906 ) 907 boundary_distances = apply_filter( 908 boundary_distances, "gaussianSmoothing", sigma=distance_smoothing, 909 block_shape=tile_shape, n_threads=n_threads 910 ) 911 912 fg_mask = ThresholdWrapper(foreground_map, foreground_threshold, operator=np.greater) 913 914 marker_map = 
MultiTransformationWrapper( 915 np.logical_and, 916 ThresholdWrapper(center_distances, center_distance_threshold, operator=np.less), 917 ThresholdWrapper(boundary_distances, boundary_distance_threshold, operator=np.less), 918 ) 919 marker_map = MultiTransformationWrapper(np.logical_and, marker_map, fg_mask) 920 921 with _array_or_zarr(marker_map.shape, dtype="uint64", chunks=tile_shape, use_zarr=optimize_memory) as markers: 922 markers = parallel_impl.label( 923 marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose, 924 ) 925 926 if segmentation is None: 927 segmentation = np.zeros(markers.shape, dtype="uint64") 928 segmentation = parallel_impl.seeded_watershed( 929 boundary_distances, seeds=markers, out=segmentation, block_shape=tile_shape, 930 halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask, 931 ) 932 933 if min_size > 0: 934 segmentation = parallel_impl.size_filter( 935 segmentation, out=segmentation, min_size=min_size, 936 block_shape=tile_shape, n_threads=n_threads, verbose=verbose 937 ) 938 939 return segmentation 940 941 942def _apply_smoothing(foreground, foreground_smoothing, tile_shape, n_threads): 943 if tile_shape is None: 944 foreground = filter_impl.gaussian_smoothing(foreground, sigma=foreground_smoothing) 945 else: 946 foreground = apply_filter( 947 foreground, "gaussianSmoothing", sigma=foreground_smoothing, 948 block_shape=tile_shape, n_threads=n_threads 949 ) 950 return foreground 951 952 953class InstanceSegmentationWithDecoder: 954 """Generates an instance segmentation without prompts, using a decoder. 955 956 Implements the same interface as `AutomaticMaskGenerator`. 957 958 Use this class as follows: 959 ```python 960 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 961 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 962 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 
963 ``` 964 965 Args: 966 predictor: The segment anything predictor. 967 decoder: The decoder to predict intermediate representations for instance segmentation. 968 """ 969 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 970 self._predictor = predictor 971 self._decoder = decoder 972 973 # The decoder outputs. 974 self._foreground = None 975 self._center_distances = None 976 self._boundary_distances = None 977 978 self._is_initialized = False 979 980 @property 981 def is_initialized(self): 982 """Whether the mask generator has already been initialized. 983 """ 984 return self._is_initialized 985 986 @torch.no_grad() 987 def initialize( 988 self, 989 image: np.ndarray, 990 image_embeddings: Optional[util.ImageEmbeddings] = None, 991 i: Optional[int] = None, 992 verbose: bool = False, 993 pbar_init: Optional[callable] = None, 994 pbar_update: Optional[callable] = None, 995 ndim: int = 2, 996 ) -> None: 997 """Initialize image embeddings and decoder predictions for an image. 998 999 Args: 1000 image: The input image, volume or timeseries. 1001 image_embeddings: Optional precomputed image embeddings. 1002 See `util.precompute_image_embeddings` for details. 1003 i: Index for the image data. Required if `image` has three spatial dimensions 1004 or a time dimension and two spatial dimensions. 1005 verbose: Whether to be verbose. By default, set to 'False'. 1006 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1007 Can be used together with pbar_update to handle napari progress bar in other thread. 1008 To enable using this function within a threadworker. 1009 pbar_update: Callback to update an external progress bar. 1010 ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2. 
1011 """ 1012 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1013 pbar_init(1, "Initialize instance segmentation with decoder") 1014 1015 if image_embeddings is None: 1016 image_embeddings = util.precompute_image_embeddings( 1017 predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose 1018 ) 1019 1020 # Get the image embeddings from the predictor. 1021 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 1022 embeddings = self._predictor.features 1023 input_shape = tuple(self._predictor.input_size) 1024 original_shape = tuple(self._predictor.original_size) 1025 1026 # Run prediction with the UNETR decoder. 1027 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1028 assert output.shape[0] == 3, f"{output.shape}" 1029 pbar_update(1) 1030 pbar_close() 1031 1032 # Set the state. 1033 self._foreground = output[0] 1034 self._center_distances = output[1] 1035 self._boundary_distances = output[2] 1036 self._i = i 1037 self._is_initialized = True 1038 1039 def _to_masks(self, segmentation, output_mode): 1040 if output_mode != "binary_mask": 1041 raise ValueError( 1042 f"Output mode {output_mode} is not supported. 
Choose one of 'instance_segmentation', 'binary_masks'" 1043 ) 1044 1045 props = regionprops(segmentation) 1046 ndim = segmentation.ndim 1047 assert ndim in (2, 3) 1048 1049 shape = segmentation.shape 1050 if ndim == 2: 1051 crop_box = [0, shape[1], 0, shape[0]] 1052 else: 1053 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1054 1055 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1056 def to_bbox_2d(bbox): 1057 y0, x0 = bbox[0], bbox[1] 1058 w = bbox[3] - x0 1059 h = bbox[2] - y0 1060 return [x0, w, y0, h] 1061 1062 def to_bbox_3d(bbox): 1063 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1064 w = bbox[5] - x0 1065 h = bbox[4] - y0 1066 d = bbox[3] - y0 1067 return [x0, w, y0, h, z0, d] 1068 1069 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1070 masks = [ 1071 { 1072 "segmentation": segmentation == prop.label, 1073 "area": prop.area, 1074 "bbox": to_bbox(prop.bbox), 1075 "crop_box": crop_box, 1076 "seg_id": prop.label, 1077 } for prop in props 1078 ] 1079 return masks 1080 1081 def generate( 1082 self, 1083 center_distance_threshold: float = 0.5, 1084 boundary_distance_threshold: float = 0.5, 1085 foreground_threshold: float = 0.5, 1086 foreground_smoothing: float = 1.0, 1087 distance_smoothing: float = 1.6, 1088 min_size: int = 0, 1089 output_mode: str = "instance_segmentation", 1090 tile_shape: Optional[Tuple[int, int]] = None, 1091 halo: Optional[Tuple[int, int]] = None, 1092 n_threads: Optional[int] = None, 1093 optimize_memory: bool = False, 1094 segmentation: Optional[np.ndarray] = None, 1095 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1096 """Generate instance segmentation for the currently initialized image. 1097 1098 Args: 1099 center_distance_threshold: Center distance predictions below this value will be 1100 used to find seeds (intersected with thresholded boundary distance predictions). 1101 By default, set to '0.5'. 
1102 boundary_distance_threshold: Boundary distance predictions below this value will be 1103 used to find seeds (intersected with thresholded center distance predictions). 1104 By default, set to '0.5'. 1105 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1106 By default, set to '0.5'. 1107 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1108 checkerboard artifacts in the prediction. By default, set to '1.0'. 1109 distance_smoothing: Sigma value for smoothing the distance predictions. 1110 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1111 output_mode: The form masks are returned in. Possible values are: 1112 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1113 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1114 By default, set to 'instance_segmentation'. 1115 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1116 This parameter is independent from the tile shape for computing the embeddings. 1117 If not given then post-processing will not be parallelized. 1118 halo: Halo for parallel post-processing. See also `tile_shape`. 1119 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1120 optimize_memory: Whether to optimize the memory consumption by allocating intermediate files. 1121 segmentation: Optional pre-allocated segmentation. 1122 1123 Returns: 1124 The segmentation masks. 1125 """ 1126 if not self.is_initialized: 1127 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1128 1129 if foreground_smoothing > 0: 1130 foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads) 1131 else: 1132 foreground = self._foreground 1133 1134 if tile_shape is None: 1135 segmentation = watershed_from_center_and_boundary_distances( 1136 center_distances=self._center_distances, 1137 boundary_distances=self._boundary_distances, 1138 foreground_map=foreground, 1139 center_distance_threshold=center_distance_threshold, 1140 boundary_distance_threshold=boundary_distance_threshold, 1141 foreground_threshold=foreground_threshold, 1142 distance_smoothing=distance_smoothing, 1143 min_size=min_size, 1144 ) 1145 else: 1146 if halo is None: 1147 raise ValueError("You must pass a value for halo if tile_shape is given.") 1148 1149 # Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling. 1150 # This is ok in terms efficiency as GPU tiles are small; shards should still be manegable for the watershed. 
1151 if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None: 1152 tile_shape = segmentation.shards 1153 1154 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1155 center_distances=self._center_distances, 1156 boundary_distances=self._boundary_distances, 1157 foreground_map=foreground, 1158 center_distance_threshold=center_distance_threshold, 1159 boundary_distance_threshold=boundary_distance_threshold, 1160 foreground_threshold=foreground_threshold, 1161 distance_smoothing=distance_smoothing, 1162 min_size=min_size, 1163 tile_shape=tile_shape, 1164 halo=halo, 1165 n_threads=n_threads, 1166 verbose=False, 1167 optimize_memory=optimize_memory, 1168 segmentation=segmentation, 1169 ) 1170 1171 if output_mode != "instance_segmentation": 1172 segmentation = self._to_masks(segmentation, output_mode) 1173 return segmentation 1174 1175 def get_state(self) -> Dict[str, Any]: 1176 """Get the initialized state of the instance segmenter. 1177 1178 Returns: 1179 Instance segmentation state. 1180 """ 1181 if not self.is_initialized: 1182 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1183 1184 return { 1185 "foreground": self._foreground, 1186 "center_distances": self._center_distances, 1187 "boundary_distances": self._boundary_distances, 1188 } 1189 1190 def set_state(self, state: Dict[str, Any]) -> None: 1191 """Set the state of the instance segmenter. 1192 1193 Args: 1194 state: The instance segmentation state 1195 """ 1196 self._foreground = state["foreground"] 1197 self._center_distances = state["center_distances"] 1198 self._boundary_distances = state["boundary_distances"] 1199 self._is_initialized = True 1200 1201 def clear_state(self): 1202 """Clear the state of the instance segmenter. 
1203 """ 1204 self._foreground = None 1205 self._center_distances = None 1206 self._boundary_distances = None 1207 self._is_initialized = False 1208 1209 1210class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1211 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1212 """ 1213 1214 # Apply the decoder in a batched fashion, and then perform the resizing independently per output. 1215 # This is necessary, because the individual tiles may have different tile shapes due to border tiles. 1216 def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes): 1217 batched_embeddings = torch.cat(batched_embeddings) 1218 output = self._decoder._forward_impl(batched_embeddings) 1219 1220 batched_output = [] 1221 for x, input_shape, original_shape in zip(output, input_shapes, original_shapes): 1222 x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0) 1223 batched_output.append(x.cpu().numpy()) 1224 return batched_output 1225 1226 @torch.no_grad() 1227 def initialize( 1228 self, 1229 image: np.ndarray, 1230 image_embeddings: Optional[util.ImageEmbeddings] = None, 1231 i: Optional[int] = None, 1232 tile_shape: Optional[Tuple[int, int]] = None, 1233 halo: Optional[Tuple[int, int]] = None, 1234 verbose: bool = False, 1235 pbar_init: Optional[callable] = None, 1236 pbar_update: Optional[callable] = None, 1237 batch_size: int = 1, 1238 mask: Optional[np.typing.ArrayLike] = None, 1239 ) -> None: 1240 """Initialize image embeddings and decoder predictions for an image. 1241 1242 Args: 1243 image: The input image, volume or timeseries. 1244 image_embeddings: Optional precomputed image embeddings. 1245 See `util.precompute_image_embeddings` for details. 1246 i: Index for the image data. Required if `image` has three spatial dimensions 1247 or a time dimension and two spatial dimensions. 1248 tile_shape: Shape of the tiles for precomputing image embeddings. 
1249 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1250 verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'. 1251 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1252 Can be used together with pbar_update to handle napari progress bar in other thread. 1253 To enable using this function within a threadworker. 1254 pbar_update: Callback to update an external progress bar. 1255 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 1256 mask: An optional mask to define areas that are ignored in the segmentation. 1257 """ 1258 original_size = image.shape[:2] 1259 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 1260 self._predictor, image, image_embeddings, tile_shape, halo, 1261 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 1262 ) 1263 tiling = Blocking([0, 0], original_size, tile_shape) 1264 1265 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1266 1267 foreground = np.zeros(original_size, dtype="float32") 1268 center_distances = np.zeros(original_size, dtype="float32") 1269 boundary_distances = np.zeros(original_size, dtype="float32") 1270 1271 msg = "Initialize tiled instance segmentation with decoder" 1272 if tiles_in_mask is None: 1273 n_tiles = tiling.number_of_blocks 1274 all_tile_ids = list(range(n_tiles)) 1275 else: 1276 n_tiles = len(tiles_in_mask) 1277 all_tile_ids = tiles_in_mask 1278 msg += " and mask" 1279 1280 n_batches = int(np.ceil(n_tiles / batch_size)) 1281 pbar_init(n_tiles, msg) 1282 tile_ids_for_batches = np.array_split(all_tile_ids, n_batches) 1283 1284 for tile_ids in tile_ids_for_batches: 1285 batched_embeddings, input_shapes, original_shapes = [], [], [] 1286 for tile_id in tile_ids: 1287 # Get the image embeddings from the predictor for this tile. 
1288 self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id) 1289 1290 batched_embeddings.append(self._predictor.features) 1291 input_shapes.append(tuple(self._predictor.input_size)) 1292 original_shapes.append(tuple(self._predictor.original_size)) 1293 1294 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1295 1296 for output_id, tile_id in enumerate(tile_ids): 1297 output = batched_output[output_id] 1298 assert output.shape[0] == 3 1299 1300 # Set the predictions in the output for this tile. 1301 block = tiling.get_block_with_halo(tile_id, halo=list(halo)) 1302 local_bb = tuple( 1303 slice(beg, end) for beg, end in zip(block.inner_block_local.begin, block.inner_block_local.end) 1304 ) 1305 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.inner_block.begin, block.inner_block.end)) 1306 1307 foreground[inner_bb] = output[0][local_bb] 1308 center_distances[inner_bb] = output[1][local_bb] 1309 boundary_distances[inner_bb] = output[2][local_bb] 1310 pbar_update(1) 1311 1312 pbar_close() 1313 1314 # Set the state. 1315 self._i = i 1316 self._foreground = foreground 1317 self._center_distances = center_distances 1318 self._boundary_distances = boundary_distances 1319 self._is_initialized = True 1320 1321 1322def _get_centers(segmentation, avoid_image_border=True): 1323 # Not sure if 'find_boundaries' has to be paralellized. 1324 # Compute the distance transform to object bounaries. 1325 boundaries = find_boundaries(segmentation, mode="outer") == 0 1326 # Avoid centroids on the border. 1327 if avoid_image_border: 1328 boundaries[0, :] = False 1329 boundaries[:, 0] = False 1330 boundaries[-1, :] = False 1331 boundaries[:, -1] = False 1332 block_shape = (512, 512) 1333 halo = (16, 16) 1334 distances = parallel_impl.distance_transform(boundaries, halo=halo, block_shape=block_shape) 1335 1336 # Get the maximum coordinate of the distance transform for each id. 
1337 props = regionprops(segmentation) 1338 centers = [] 1339 for prop in props: 1340 seg_id = prop.label 1341 # Get the bounding box and mask for this segmentation id. 1342 bounding_box = np.s_[prop.bbox[0]:prop.bbox[2], prop.bbox[1]:prop.bbox[3]] 1343 mask = segmentation[bounding_box] == seg_id 1344 # Restrict the distances to the bounding box. 1345 dist = distances[bounding_box].copy() 1346 # Set the distances outside of the mask to 0. 1347 dist[~mask] = 0 1348 # Find the center coordinate (= distance maximum). 1349 center = np.argmax(dist) 1350 center = np.unravel_index(center, dist.shape) 1351 # Bring the center coordinate back to the full coordinate system. 1352 center = tuple(ce + bb.start for ce, bb in zip(center, bounding_box)) 1353 centers.append(center) 1354 1355 return np.array(centers) 1356 1357 1358def _derive_point_prompts( 1359 foreground: np.ndarray, 1360 center_distances: np.ndarray, 1361 boundary_distances: np.ndarray, 1362 foreground_threshold: float = 0.5, 1363 center_distance_threshold: float = 0.5, 1364 boundary_distance_threshold: float = 0.5, 1365): 1366 bg_mask = foreground < foreground_threshold 1367 hmap_cc = np.logical_and( 1368 center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold, 1369 ) 1370 hmap_cc[bg_mask] = 0 1371 cc = np.zeros_like(hmap_cc, dtype="uint32") 1372 cc = parallel_impl.label(hmap_cc, out=cc, block_shape=(512, 512)) 1373 prompts = _get_centers(cc) 1374 if len(prompts) == 0: 1375 return None 1376 1377 points = prompts[:, None, ::-1] 1378 labels = np.ones((len(prompts), 1)) 1379 return {"points": points, "point_labels": labels} 1380 1381 1382def _derive_box_prompts(predictions, box_extension): 1383 shape = predictions[0]["segmentation"].shape 1384 bboxes = [pred["bbox"] for pred in predictions] 1385 prompts = [[ 1386 max(x - w * box_extension, 0), 1387 max(y - h * box_extension, 0), 1388 min(x + (1 + box_extension) * w, shape[0]), 1389 min(y + (1 + box_extension) * h, shape[1]), 
1390 ] for (x, y, w, h) in bboxes] 1391 return {"boxes": np.array(prompts)} 1392 1393 1394class AutomaticPromptGenerator(InstanceSegmentationWithDecoder): 1395 """Generates an instance segmentation automatically, using automatically generated prompts from a decoder. 1396 1397 This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator` 1398 1399 Args: 1400 predictor: The segment anything predictor. 1401 decoder: The derive prompts for automatic instance segmentation. 1402 """ 1403 def generate( 1404 self, 1405 min_size: int = 25, 1406 center_distance_threshold: float = 0.5, 1407 boundary_distance_threshold: float = 0.5, 1408 foreground_threshold: float = 0.5, 1409 multimasking: bool = False, 1410 batch_size: int = 32, 1411 nms_threshold: float = 0.9, 1412 intersection_over_min: bool = False, 1413 output_mode: str = "instance_segmentation", 1414 mask_threshold: Optional[Union[float, str]] = None, 1415 refine_with_box_prompts: bool = False, 1416 prompt_function: Optional[callable] = None, 1417 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1418 """Generate the instance segmentation for the currently initialized image. 1419 1420 The instance segmentation is generated by deriving prompts from the foreground and 1421 distance predictions of the segmentation decoder by thresholding these predictions, 1422 intersecting them, computing connected components, and then using the component's 1423 centers as point prompts. The masks are then filtered via NMS and merged into a segmentation. 1424 1425 Args: 1426 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1427 center_distance_threshold: The threshold for the center distance predictions. 1428 boundary_distance_threshold: The threshold for the boundary distance predictions. 1429 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1430 batch_size: The batch size for parallelizing the prediction based on prompts. 
1431 nms_threshold: The threshold for non-maximum suppression (NMS). 1432 intersection_over_min: Whether to use the minimum area of the two objects or the 1433 intersection over union area (default) in NMS. 1434 output_mode: The form masks are returned in. Possible values are: 1435 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1436 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1437 By default, set to 'instance_segmentation'. 1438 mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.` 1439 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1440 derived from the segmentations after point prompts. 1441 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1442 If given, the default prompt derivation procedure is not used. Must have the following signature: 1443 ``` 1444 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1445 ``` 1446 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1447 predictions from the segmentation decoder. It must returns a dictionary containing 1448 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1449 1450 Returns: 1451 The instance segmentation masks. 1452 """ 1453 if not self.is_initialized: 1454 raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.") 1455 foreground, center_distances, boundary_distances =\ 1456 self._foreground, self._center_distances, self._boundary_distances 1457 1458 # 1.) Derive promtps from the decoder predictions. 
1459 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1460 prompts = prompt_function( 1461 foreground=foreground, 1462 center_distances=center_distances, 1463 boundary_distances=boundary_distances, 1464 foreground_threshold=foreground_threshold, 1465 center_distance_threshold=center_distance_threshold, 1466 boundary_distance_threshold=boundary_distance_threshold, 1467 ) 1468 1469 # 2.) Apply the predictor to the prompts. 1470 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1471 return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else [] 1472 else: 1473 predictions = batched_inference( 1474 self._predictor, 1475 image=None, 1476 batch_size=batch_size, 1477 return_instance_segmentation=False, 1478 multimasking=multimasking, 1479 mask_threshold=mask_threshold, 1480 i=getattr(self, "_i", None), 1481 **prompts, 1482 ) 1483 1484 # 3.) Refine the segmentation with box prompts. 1485 if refine_with_box_prompts: 1486 box_extension = 0.01 # expose as hyperparam? 1487 prompts = _derive_box_prompts(predictions, box_extension) 1488 predictions = batched_inference( 1489 self._predictor, 1490 image=None, 1491 batch_size=batch_size, 1492 return_instance_segmentation=False, 1493 multimasking=multimasking, 1494 mask_threshold=mask_threshold, 1495 i=getattr(self, "_i", None), 1496 **prompts, 1497 ) 1498 1499 # 4.) Apply non-max suppression to the masks. 1500 segmentation = util.apply_nms( 1501 predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1502 ) 1503 if output_mode != "instance_segmentation": 1504 segmentation = self._to_masks(segmentation, output_mode) 1505 return segmentation 1506 1507 1508class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder): 1509 """Same as `AutomaticPromptGenerator` but for tiled image embeddings. 
1510 """ 1511 def generate( 1512 self, 1513 min_size: int = 25, 1514 center_distance_threshold: float = 0.5, 1515 boundary_distance_threshold: float = 0.5, 1516 foreground_threshold: float = 0.5, 1517 multimasking: bool = False, 1518 batch_size: int = 32, 1519 nms_threshold: float = 0.9, 1520 intersection_over_min: bool = False, 1521 output_mode: str = "instance_segmentation", 1522 mask_threshold: Optional[Union[float, str]] = None, 1523 refine_with_box_prompts: bool = False, 1524 prompt_function: Optional[callable] = None, 1525 optimize_memory: bool = False, 1526 ) -> List[Dict[str, Any]]: 1527 """Generate tiling-based instance segmentation for the currently initialized image. 1528 1529 Args: 1530 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1531 center_distance_threshold: The threshold for the center distance predictions. 1532 boundary_distance_threshold: The threshold for the boundary distance predictions. 1533 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1534 batch_size: The batch size for parallelizing the prediction based on prompts. 1535 nms_threshold: The threshold for non-maximum suppression (NMS). 1536 intersection_over_min: Whether to use the minimum area of the two objects or the 1537 intersection over union area (default) in NMS. 1538 output_mode: The form masks are returned in. Possible values are: 1539 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1540 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1541 By default, set to 'instance_segmentation'. 1542 mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.` 1543 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1544 derived from the segmentations after point prompts. Currently not supported for tiled segmentation. 
1545 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1546 If given, the default prompt derivation procedure is not used. Must have the following signature: 1547 ``` 1548 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1549 ``` 1550 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1551 predictions from the segmentation decoder. It must returns a dictionary containing 1552 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1553 optimize_memory: Whether to optimize the memory consumption by merging the per-slice 1554 segmentation results immediatly with NMS, rather than running a NMS for all results. 1555 This may lead to a slightly different segmentation result and is only compatible with 1556 `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`. 1557 1558 Returns: 1559 The instance segmentation masks. 1560 """ 1561 if not self.is_initialized: 1562 raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.") 1563 if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts): 1564 raise ValueError("Invalid settings") 1565 foreground, center_distances, boundary_distances =\ 1566 self._foreground, self._center_distances, self._boundary_distances 1567 1568 # 1.) Derive promtps from the decoder predictions. 1569 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1570 prompts = prompt_function( 1571 foreground, 1572 center_distances, 1573 boundary_distances, 1574 foreground_threshold=foreground_threshold, 1575 center_distance_threshold=center_distance_threshold, 1576 boundary_distance_threshold=boundary_distance_threshold, 1577 ) 1578 1579 # 2.) Apply the predictor to the prompts. 
1580 shape = foreground.shape 1581 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1582 return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else [] 1583 else: 1584 if optimize_memory: 1585 prompts.update(dict( 1586 min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1587 )) 1588 predictions = batched_tiled_inference( 1589 self._predictor, 1590 image=None, 1591 batch_size=batch_size, 1592 image_embeddings=self._image_embeddings, 1593 return_instance_segmentation=False, 1594 multimasking=multimasking, 1595 optimize_memory=optimize_memory, 1596 i=getattr(self, "_i", None), 1597 **prompts 1598 ) 1599 # Optimize memory directly returns an instance segmentation and does not 1600 # allow for any further refinements. 1601 if optimize_memory: 1602 return predictions 1603 1604 # 3.) Refine the segmentation with box prompts. 1605 if refine_with_box_prompts: 1606 # TODO 1607 raise NotImplementedError 1608 1609 # 4.) Apply non-max suppression to the masks. 1610 segmentation = util.apply_nms( 1611 predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold, 1612 intersection_over_min=intersection_over_min, 1613 ) 1614 if output_mode != "instance_segmentation": 1615 segmentation = self._to_masks(segmentation, output_mode) 1616 return segmentation 1617 1618 # Set state and get state are not implemented yet, as this generator relies on having the image embeddings 1619 # in the state. However, they should not be serialized here and we have to address this a bit differently. 
1620 def get_state(self): 1621 """@private 1622 """ 1623 raise NotImplementedError 1624 1625 def set_state(self, state): 1626 """@private 1627 """ 1628 raise NotImplementedError 1629 1630 1631def get_instance_segmentation_generator( 1632 predictor: SamPredictor, 1633 is_tiled: bool, 1634 decoder: Optional[torch.nn.Module] = None, 1635 segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None, 1636 **kwargs, 1637) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1638 f"""Get the automatic mask generator. 1639 1640 Args: 1641 predictor: The segment anything predictor. 1642 is_tiled: Whether tiled embeddings are used. 1643 decoder: Decoder to predict instacne segmmentation. 1644 segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'. 1645 By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used if a decoder is passed, 1646 otherwise 'amg' is used. 1647 kwargs: The keyword arguments of the segmentation genetator class. 1648 1649 Returns: 1650 The segmentation generator instance. 1651 """ 1652 # Choose the segmentation decoder default depending on whether we have a decoder. 
1653 if segmentation_mode is None: 1654 segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER 1655 1656 if segmentation_mode.lower() == "amg": 1657 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1658 segmenter = segmenter_class(predictor, **kwargs) 1659 elif segmentation_mode.lower() == "ais": 1660 assert decoder is not None 1661 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1662 segmenter = segmenter_class(predictor, decoder, **kwargs) 1663 elif segmentation_mode.lower() == "apg": 1664 assert decoder is not None 1665 segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator 1666 segmenter = segmenter_class(predictor, decoder, **kwargs) 1667 else: 1668 raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.") 1669 1670 return segmenter
class AMGBase(ABC):
    """Base class for the automatic mask generators.

    Holds the state computed by the `initialize` method of the child classes (the per-crop mask data,
    the crop boxes and the original image size) and implements the mask post-processing shared by them.
    Child classes are expected to set `self._predictor` and `self._stability_score_offset`,
    which are read by `_to_mask_data` but not set here.
    """
    def __init__(self):
        # the state that has to be computed by the 'initialize' method of the child classes
        self._is_initialized = False
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None

    @property
    def is_initialized(self):
        """Whether the mask generator has already been initialized.
        """
        return self._is_initialized

    @property
    def crop_list(self):
        """The list of mask data after initialization.
        """
        return self._crop_list

    @property
    def crop_boxes(self):
        """The list of crop boxes.
        """
        return self._crop_boxes

    @property
    def original_size(self):
        """The original image size.
        """
        return self._original_size

    def _postprocess_batch(
        self,
        data,
        crop_box,
        original_size,
        pred_iou_thresh,
        stability_score_thresh,
        box_nms_thresh,
    ):
        """Filter the masks predicted for one crop and map them back to the original image frame.

        `data` is modified in place (via `filter` and item assignment) and returned,
        so callers that want to keep the unfiltered data must pass a copy.

        Args:
            data: The mask data of one crop, as produced by `_to_mask_data`.
            crop_box: The crop box in xyxy format.
            original_size: The (height, width) of the original image.
            pred_iou_thresh: Keep masks with predicted IoU strictly above this value (skipped if <= 0).
            stability_score_thresh: Keep masks with stability score >= this value (skipped if <= 0).
            box_nms_thresh: IoU threshold for the box-based non-max suppression within this crop.

        Returns:
            The filtered mask data, with boxes (and points, if present) in the original image frame
            and a 'crop_boxes' entry holding `crop_box` once per remaining mask.
        """
        orig_h, orig_w = original_size

        # filter by predicted IoU
        if pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > pred_iou_thresh
            data.filter(keep_mask)

        # filter by stability score
        if stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= stability_score_thresh
            data.filter(keep_mask)

        # filter boxes that touch crop boundaries
        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # return to the original image frame
        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        # the data from embedding based segmentation doesn't have the points
        # so we skip if the corresponding key can't be found
        try:
            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
        except KeyError:
            pass

        return data

    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
        """Remove small holes and islands from the masks and drop duplicates that this creates.

        Masks that were modified by the region removal get an NMS score of 0, unchanged masks a score of 1,
        so that the non-max suppression prefers masks that did not need post-processing.
        RLEs and boxes are recomputed only for the modified masks that survive the NMS.

        Args:
            mask_data: The mask data; must contain 'rles' and 'boxes'.
            min_area: The area threshold passed to `amg_utils.remove_small_regions`.
            nms_thresh: IoU threshold for the box-based non-max suppression.

        Returns:
            The filtered mask data (modified in place). Returned unchanged if it contains no masks.
        """

        if len(mask_data["rles"]) == 0:
            return mask_data

        # filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = amg_utils.rle_to_mask(rle)

            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
            # give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores, dtype=torch.float),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                # The vendored mask_to_rle_pytorch is used instead of amg_utils.mask_to_rle_pytorch:
                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data

    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
        """Optionally clean up small regions, then encode the masks and build one annotation dict per mask.

        Args:
            mask_data: The (filtered) mask data.
            min_mask_region_area: If > 0, remove holes and islands below this area via `_postprocess_small_regions`.
            box_nms_thresh: NMS threshold within crops; the larger of it and `crop_nms_thresh` is used
                for the NMS after small region removal.
            crop_nms_thresh: NMS threshold between crops.
            output_mode: One of 'coco_rle', 'binary_mask', 'instance_segmentation' or 'rle'.
                Determines the encoding of the 'segmentation' entry.

        Returns:
            A list of dicts with the keys 'segmentation', 'area', 'bbox', 'predicted_iou',
            'stability_score', 'crop_box' and, if points are present, 'point_coords'.

        Raises:
            ValueError: If `output_mode` is not supported.
        """
        # filter small disconnected regions and holes in masks
        if min_mask_region_area > 0:
            mask_data = self._postprocess_small_regions(
                mask_data,
                min_mask_region_area,
                max(box_nms_thresh, crop_nms_thresh),
            )

        # encode masks
        if output_mode == "coco_rle":
            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
        # We require the binary mask output for creating the instance-seg output.
        elif output_mode in ("binary_mask", "instance_segmentation"):
            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
        elif output_mode == "rle":
            mask_data["segmentations"] = mask_data["rles"]
        else:
            raise ValueError(f"Invalid output mode {output_mode}.")

        # write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            # the data from embedding based segmentation doesn't have the points
            # so we skip if the corresponding key can't be found
            try:
                ann["point_coords"] = [mask_data["points"][idx].tolist()]
            except KeyError:
                pass
            curr_anns.append(ann)

        return curr_anns

    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
        """Convert raw predictor outputs for one batch of prompts into mask data with RLE-encoded masks.

        Args:
            masks: The predicted masks; their first two dimensions are flattened into one.
                Presumably (n_prompts, n_masks_per_prompt, H, W) mask logits — TODO confirm.
            iou_preds: The predicted IoU scores, flattened the same way as `masks`.
            crop_box: The crop box the masks were predicted for; used to uncrop the masks.
            original_size: The (height, width) of the original image.
            points: Optional point prompts; each point is repeated once per mask predicted for it
                (the repeat factor is `masks.shape[1]`).

        Returns:
            The mask data with the entries 'iou_preds', 'stability_score', 'boxes', 'rles'
            and, if `points` is given, 'points'. The dense masks are removed.
        """
        orig_h, orig_w = original_size

        # serialize predictions and store in MaskData
        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
        if points is not None:
            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)

        # Drop the reference to the unflattened masks early; data["masks"] is used from here on.
        del masks

        # calculate the stability scores
        data["stability_score"] = amg_utils.calculate_stability_score(
            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
        )

        # threshold masks and calculate boxes
        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
        data["masks"] = data["masks"].type(torch.bool)
        data["boxes"] = batched_mask_to_box(data["masks"])

        # compress to RLE
        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        # The vendored mask_to_rle_pytorch is used instead of amg_utils.mask_to_rle_pytorch:
        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    def get_state(self) -> Dict[str, Any]:
        """Get the initialized state of the mask generator.

        Returns:
            State of the mask generator, a dict with the keys 'crop_list', 'crop_boxes' and 'original_size'.

        Raises:
            RuntimeError: If the mask generator has not been initialized.
        """
        if not self.is_initialized:
            raise RuntimeError("The state has not been computed yet. Call initialize first.")

        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

    def set_state(self, state: Dict[str, Any]) -> None:
        """Set the state of the mask generator.

        Args:
            state: The state of the mask generator, e.g. from serialized state.
                Must contain the keys 'crop_list', 'crop_boxes' and 'original_size', as returned by `get_state`.
        """
        self._crop_list = state["crop_list"]
        self._crop_boxes = state["crop_boxes"]
        self._original_size = state["original_size"]
        self._is_initialized = True

    def clear_state(self):
        """Clear the state of the mask generator.

        Resets the crop list, crop boxes and original size to None and marks the generator as uninitialized.
        """
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None
        self._is_initialized = False
Base class for the automatic mask generators.
76 @property 77 def is_initialized(self): 78 """Whether the mask generator has already been initialized. 79 """ 80 return self._is_initialized
Whether the mask generator has already been initialized.
82 @property 83 def crop_list(self): 84 """The list of mask data after initialization. 85 """ 86 return self._crop_list
The list of mask data after initialization.
88 @property 89 def crop_boxes(self): 90 """The list of crop boxes. 91 """ 92 return self._crop_boxes
The list of crop boxes.
94 @property 95 def original_size(self): 96 """The original image size. 97 """ 98 return self._original_size
The original image size.
258 def get_state(self) -> Dict[str, Any]: 259 """Get the initialized state of the mask generator. 260 261 Returns: 262 State of the mask generator. 263 """ 264 if not self.is_initialized: 265 raise RuntimeError("The state has not been computed yet. Call initialize first.") 266 267 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
269 def set_state(self, state: Dict[str, Any]) -> None: 270 """Set the state of the mask generator. 271 272 Args: 273 state: The state of the mask generator, e.g. from serialized state. 274 """ 275 self._crop_list = state["crop_list"] 276 self._crop_boxes = state["crop_boxes"] 277 self._original_size = state["original_size"] 278 self._is_initialized = True
Set the state of the mask generator.
Arguments:
- state: The state of the mask generator, e.g. from serialized state.
class AutomaticMaskGenerator(AMGBase):
    """Generates an instance segmentation without prompts, using a point grid.

    This class implements the same logic as
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
    to filter these masks to enable grid search and interactively changing the post-processing.

    Use this class as follows:
    ```python
    amg = AutomaticMaskGenerator(predictor)
    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
    ```

    Args:
        predictor: The segment anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
            If both `points_per_side` and `point_grids` are given, `point_grids` is ignored and a warning is emitted.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory.
            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
            By default, set to '0'.
        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
            By default, set to '1'.
        point_grids: A list over explicit grids of points used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
            By default, set to '1.0'.

    Raises:
        ValueError: If both `points_per_side` and `point_grids` are None.
    """
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: Optional[int] = None,
        crop_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ):
        super().__init__()

        if points_per_side is not None:
            if point_grids is not None:
                # points_per_side has a non-None default, so explicitly passed point_grids
                # would otherwise be dropped without any notice.
                warnings.warn(
                    "Both points_per_side and point_grids were passed; point_grids will be ignored. "
                    "Set points_per_side=None to use the explicit point_grids."
                )
            self.point_grids = amg_utils.build_all_layer_point_grids(
                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grids be None.")

        self._predictor = predictor
        self._points_per_side = points_per_side

        # we set the points per batch to 16 for mps for performance reasons
        # and otherwise keep them at the default of 64
        if points_per_batch is None:
            points_per_batch = 16 if str(predictor.device) == "mps" else 64
        self._points_per_batch = points_per_batch

        self._crop_n_layers = crop_n_layers
        self._crop_overlap_ratio = crop_overlap_ratio
        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self._stability_score_offset = stability_score_offset

    def _process_batch(self, points, im_size, crop_box, original_size):
        """Predict multimask logits for one batch of point prompts and convert them to mask data."""
        # run model on this batch
        transformed_points = self._predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
        # Every point is a positive (foreground) prompt.
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self._predictor.predict_torch(
            point_coords=in_points[:, None, :],
            point_labels=in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )
        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
        del masks
        return data

    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
        """Predict the masks for the point grid of one crop and return the concatenated mask data."""
        # Crop the image and calculate embeddings.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]

        if not precomputed_embeddings:
            self._predictor.set_image(cropped_im)

        # Get the points for this crop: scale the normalized grid by (width, height).
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches.
        data = amg_utils.MaskData()
        n_batches = len(points_for_image) // self._points_per_batch +\
            int(len(points_for_image) % self._points_per_batch != 0)
        if pbar_init is not None:
            pbar_init(n_batches, "Predict masks for point grid prompts")

        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
            data.cat(batch_data)
            del batch_data
            if pbar_update is not None:
                pbar_update(1)

        if not precomputed_embeddings:
            self._predictor.reset_image()

        return data

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            verbose: Whether to print computation progress. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                To enable using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
        """
        original_size = image.shape[:2]
        self._original_size = original_size

        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
            original_size, self._crop_n_layers, self._crop_overlap_ratio
        )

        # We can set fixed image embeddings if we only have a single crop box (the default setting).
        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
        if len(crop_boxes) == 1:
            if image_embeddings is None:
                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
            util.set_precomputed(self._predictor, image_embeddings, i=i)
            precomputed_embeddings = True
        else:
            precomputed_embeddings = False

        # we need to cast to the image representation that is compatible with SAM
        image = util._to_image(image)

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

        crop_list = []
        try:
            for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
                crop_data = self._process_crop(
                    image, crop_box, layer_idx,
                    precomputed_embeddings=precomputed_embeddings,
                    pbar_init=pbar_init, pbar_update=pbar_update,
                )
                crop_list.append(crop_data)
        finally:
            # Close the progress bar even if mask prediction fails.
            pbar_close()

        self._is_initialized = True
        self._crop_list = crop_list
        self._crop_boxes = crop_boxes

    @torch.no_grad()
    def generate(
        self,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        box_nms_thresh: float = 0.7,
        crop_nms_thresh: float = 0.7,
        min_mask_region_area: int = 0,
        output_mode: str = "instance_segmentation",
        with_background: bool = True,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate instance segmentation for the currently initialized image.

        Args:
            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
                By default, set to '0.88'.
            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
                By default, set to '0.7'.
            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
                By default, set to '0.7'.
            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                - 'rle': Return a list of dictionaries with run-length encoded masks.
                By default, set to 'instance_segmentation'.
            with_background: Whether to remove the largest object, which often covers the background.

        Returns:
            The segmentation masks.

        Raises:
            RuntimeError: If the generator has not been initialized.
        """
        if not self.is_initialized:
            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")

        data = amg_utils.MaskData()
        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
            # Filter a copy, so that the cached predictions can be reused with other thresholds.
            crop_data = self._postprocess_batch(
                data=deepcopy(data_),
                crop_box=crop_box, original_size=self.original_size,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                box_nms_thresh=box_nms_thresh
            )
            data.cat(crop_data)

        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
        if output_mode == "instance_segmentation":
            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
            masks = util.mask_data_to_segmentation(
                masks, shape=shape, with_background=with_background, merge_exclusively=False
            )
        return masks
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive mask generation from the cheap post-processing that filters these masks, which enables grid search and interactive changes of the post-processing parameters.
Use this class as follows:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image) # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image.
If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image. By default, set to '0'.
- crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. By default, set to '1'.
- point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
321 def __init__( 322 self, 323 predictor: SamPredictor, 324 points_per_side: Optional[int] = 32, 325 points_per_batch: Optional[int] = None, 326 crop_n_layers: int = 0, 327 crop_overlap_ratio: float = 512 / 1500, 328 crop_n_points_downscale_factor: int = 1, 329 point_grids: Optional[List[np.ndarray]] = None, 330 stability_score_offset: float = 1.0, 331 ): 332 super().__init__() 333 334 if points_per_side is not None: 335 self.point_grids = amg_utils.build_all_layer_point_grids( 336 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 337 ) 338 elif point_grids is not None: 339 self.point_grids = point_grids 340 else: 341 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 342 343 self._predictor = predictor 344 self._points_per_side = points_per_side 345 346 # we set the points per batch to 16 for mps for performance reasons 347 # and otherwise keep them at the default of 64 348 if points_per_batch is None: 349 points_per_batch = 16 if str(predictor.device) == "mps" else 64 350 self._points_per_batch = points_per_batch 351 352 self._crop_n_layers = crop_n_layers 353 self._crop_overlap_ratio = crop_overlap_ratio 354 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 355 self._stability_score_offset = stability_score_offset
404 @torch.no_grad() 405 def initialize( 406 self, 407 image: np.ndarray, 408 image_embeddings: Optional[util.ImageEmbeddings] = None, 409 i: Optional[int] = None, 410 verbose: bool = False, 411 pbar_init: Optional[callable] = None, 412 pbar_update: Optional[callable] = None, 413 ) -> None: 414 """Initialize image embeddings and masks for an image. 415 416 Args: 417 image: The input image, volume or timeseries. 418 image_embeddings: Optional precomputed image embeddings. 419 See `util.precompute_image_embeddings` for details. 420 i: Index for the image data. Required if `image` has three spatial dimensions 421 or a time dimension and two spatial dimensions. 422 verbose: Whether to print computation progress. By default, set to 'False'. 423 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 424 Can be used together with pbar_update to handle napari progress bar in other thread. 425 To enable using this function within a threadworker. 426 pbar_update: Callback to update an external progress bar. 427 """ 428 original_size = image.shape[:2] 429 self._original_size = original_size 430 431 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 432 original_size, self._crop_n_layers, self._crop_overlap_ratio 433 ) 434 435 # We can set fixed image embeddings if we only have a single crop box (the default setting). 436 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 
437 if len(crop_boxes) == 1: 438 if image_embeddings is None: 439 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 440 util.set_precomputed(self._predictor, image_embeddings, i=i) 441 precomputed_embeddings = True 442 else: 443 precomputed_embeddings = False 444 445 # we need to cast to the image representation that is compatible with SAM 446 image = util._to_image(image) 447 448 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 449 450 crop_list = [] 451 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 452 crop_data = self._process_crop( 453 image, crop_box, layer_idx, 454 precomputed_embeddings=precomputed_embeddings, 455 pbar_init=pbar_init, pbar_update=pbar_update, 456 ) 457 crop_list.append(crop_data) 458 pbar_close() 459 460 self._is_initialized = True 461 self._crop_list = crop_list 462 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
464 @torch.no_grad() 465 def generate( 466 self, 467 pred_iou_thresh: float = 0.88, 468 stability_score_thresh: float = 0.95, 469 box_nms_thresh: float = 0.7, 470 crop_nms_thresh: float = 0.7, 471 min_mask_region_area: int = 0, 472 output_mode: str = "instance_segmentation", 473 with_background: bool = True, 474 ) -> Union[List[Dict[str, Any]], np.ndarray]: 475 """Generate instance segmentation for the currently initialized image. 476 477 Args: 478 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 479 By default, set to '0.88'. 480 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 481 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 482 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 483 By default, set to '0.7'. 484 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 485 By default, set to '0.7'. 486 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 487 output_mode: The form masks are returned in. Possible values are: 488 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 489 - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format. 490 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 491 - 'rle': Return a list of dictionaries with run-length encoded masks. 492 By default, set to 'instance_segmentation'. 493 with_background: Whether to remove the largest object, which often covers the background. 494 495 Returns: 496 The segmentation masks. 497 """ 498 if not self.is_initialized: 499 raise RuntimeError("AutomaticMaskGenerator has not been initialized. 
Call initialize first.") 500 501 data = amg_utils.MaskData() 502 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 503 crop_data = self._postprocess_batch( 504 data=deepcopy(data_), 505 crop_box=crop_box, original_size=self.original_size, 506 pred_iou_thresh=pred_iou_thresh, 507 stability_score_thresh=stability_score_thresh, 508 box_nms_thresh=box_nms_thresh 509 ) 510 data.cat(crop_data) 511 512 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 513 # Prefer masks from smaller crops 514 scores = 1 / box_area(data["crop_boxes"]) 515 scores = scores.to(data["boxes"].device) 516 keep_by_nms = batched_nms( 517 data["boxes"].float(), 518 scores, 519 torch.zeros_like(data["boxes"][:, 0]), # categories 520 iou_threshold=crop_nms_thresh, 521 ) 522 data.filter(keep_by_nms) 523 524 data.to_numpy() 525 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 526 if output_mode == "instance_segmentation": 527 shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size 528 masks = util.mask_data_to_segmentation( 529 masks, shape=shape, with_background=with_background, merge_exclusively=False 530 ) 531 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. By default, set to '0.88'.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. By default, set to '0.7'.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. By default, set to '0.7'.
- min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
- 'rle': Return a list of dictionaries with run-length encoded masks.
By default, set to 'instance_segmentation'.
- with_background: Whether to remove the largest object, which often covers the background.
Returns:
The segmentation masks.
Inherited Members
class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
    """Generates an instance segmentation without prompts, using a point grid.

    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.

    Args:
        predictor: The Segment Anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
        point_grids: A list over explicit grids of points used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
            By default, set to '1.0'.
    """

    # We only expose the arguments that make sense for the tiled mask generator.
    # Anything related to crops doesn't make sense, because we re-use that functionality
    # for tiling, so these parameters wouldn't have any effect.
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ) -> None:
        super().__init__(
            predictor=predictor,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            point_grids=point_grids,
            stability_score_offset=stability_score_offset,
        )

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
        mask: Optional[np.typing.ArrayLike] = None,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: The tile shape for embedding prediction.
            halo: The overlap between tiles.
            verbose: Whether to print computation progress. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                This enables using this function within a thread worker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding prediction. By default, set to '1'.
            mask: An optional mask to define areas that are ignored in the segmentation.
        """
        original_size = image.shape[:2]
        self._original_size = original_size

        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
            self._predictor, image, image_embeddings, tile_shape, halo,
            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
        )

        tiling = Blocking([0, 0], original_size, tile_shape)
        if tiles_in_mask is None:
            n_tiles = tiling.number_of_blocks
            tile_ids = range(n_tiles)
        else:
            n_tiles = len(tiles_in_mask)
            tile_ids = tiles_in_mask

        # The crop box is always the full local tile (including the halo), given as [x0, y0, x1, y1].
        tiles = [tiling.get_block_with_halo(tile_id, list(halo)).outer_block for tile_id in tile_ids]
        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
        pbar_init(n_tiles, "Compute masks for tile")

        # We need to cast to the image representation that is compatible with SAM.
        image = util._to_image(image)

        mask_data = []
        for idx, tile_id in enumerate(tile_ids):
            # Set the pre-computed embeddings for this tile. Read them from the embeddings returned by
            # `_process_tiled_embeddings`, not from the `image_embeddings` argument: the argument is None
            # when the embeddings are computed above.
            # NOTE(review): assumes `_process_tiled_embeddings` returns the given embeddings when they are passed.
            features = self._image_embeddings["features"][str(tile_id)]
            tile_embeddings = {
                "features": features,
                "input_size": features.attrs["input_size"],
                "original_size": features.attrs["original_size"],
            }
            util.set_precomputed(self._predictor, tile_embeddings, i)

            # Compute the mask data for this tile and append it.
            this_mask_data = self._process_crop(
                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
            )
            mask_data.append(this_mask_data)
            pbar_update(1)
        pbar_close()

        # Set the initialized data.
        self._is_initialized = True
        self._crop_list = mask_data
        self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.
Arguments:
- predictor: The Segment Anything predictor.
- points_per_side: The number of points to be sampled along one side of the image.
If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, set to '64'.
- point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
def __init__(
    self,
    predictor: SamPredictor,
    points_per_side: Optional[int] = 32,
    points_per_batch: int = 64,
    point_grids: Optional[List[np.ndarray]] = None,
    stability_score_offset: float = 1.0,
) -> None:
    # Only the point-grid and batching options are forwarded to the parent generator.
    # Crop-related options are intentionally not exposed: the crop mechanism is re-used for tiling.
    generator_kwargs = dict(
        predictor=predictor,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        point_grids=point_grids,
        stability_score_offset=stability_score_offset,
    )
    super().__init__(**generator_kwargs)
@torch.no_grad()
def initialize(
    self,
    image: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = False,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    batch_size: int = 1,
    mask: Optional[np.typing.ArrayLike] = None,
) -> None:
    """Initialize image embeddings and masks for an image.

    Args:
        image: The input image, volume or timeseries.
        image_embeddings: Optional precomputed image embeddings.
            See `util.precompute_image_embeddings` for details.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_shape: The tile shape for embedding prediction.
        halo: The overlap between tiles.
        verbose: Whether to print computation progress. By default, set to 'False'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.
        batch_size: The batch size for image embedding prediction. By default, set to '1'.
        mask: An optional mask to define areas that are ignored in the segmentation.
    """
    original_size = image.shape[:2]
    self._original_size = original_size

    self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
        self._predictor, image, image_embeddings, tile_shape, halo,
        verbose=verbose, batch_size=batch_size, mask=mask, i=i,
    )

    tiling = Blocking([0, 0], original_size, tile_shape)
    if tiles_in_mask is None:
        n_tiles = tiling.number_of_blocks
        tile_ids = range(n_tiles)
    else:
        n_tiles = len(tiles_in_mask)
        tile_ids = tiles_in_mask

    # The crop box is always the full local tile (including the halo), given as [x0, y0, x1, y1].
    tiles = [tiling.get_block_with_halo(tile_id, list(halo)).outer_block for tile_id in tile_ids]
    crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
    pbar_init(n_tiles, "Compute masks for tile")

    # We need to cast to the image representation that is compatible with SAM.
    image = util._to_image(image)

    mask_data = []
    for idx, tile_id in enumerate(tile_ids):
        # Set the pre-computed embeddings for this tile. Read them from the embeddings returned by
        # `_process_tiled_embeddings`, not from the `image_embeddings` argument: the argument is None
        # when the embeddings are computed above.
        # NOTE(review): assumes `_process_tiled_embeddings` returns the given embeddings when they are passed.
        features = self._image_embeddings["features"][str(tile_id)]
        tile_embeddings = {
            "features": features,
            "input_size": features.attrs["input_size"],
            "original_size": features.attrs["original_size"],
        }
        util.set_precomputed(self._predictor, tile_embeddings, i)

        # Compute the mask data for this tile and append it.
        this_mask_data = self._process_crop(
            image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
        )
        mask_data.append(this_mask_data)
        pbar_update(1)
    pbar_close()

    # Set the initialized data.
    self._is_initialized = True
    self._crop_list = mask_data
    self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding prediction. By default, set to '1'.
- mask: An optional mask to define areas that are ignored in the segmentation.
class DecoderAdapter(torch.nn.Module):
    """Wrap the decoder part of a UNETR into a standalone module.

    This allows running the UNETR decoder directly on pre-computed image embeddings,
    which is what the segmentation functionality needs.
    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
    """
    def __init__(self, unetr: torch.nn.Module):
        super().__init__()

        # Re-use (not copy) the decoder-related members of the UNETR, in a fixed registration order.
        for member_name in (
            "base", "out_conv", "deconv_out", "decoder_head", "final_activation", "postprocess_masks",
            "decoder", "deconv1", "deconv2", "deconv3", "deconv4",
        ):
            setattr(self, member_name, getattr(unetr, member_name))

    def _forward_impl(self, input_):
        embeddings = input_

        # Chain of deconvolutions; the first three results are passed to the decoder as
        # `encoder_inputs`, the last one is concatenated before the decoder head.
        skip_a = self.deconv1(embeddings)
        skip_b = self.deconv2(skip_a)
        skip_c = self.deconv3(skip_b)
        skip_last = self.deconv4(skip_c)

        decoded = self.decoder(self.base(embeddings), encoder_inputs=[skip_a, skip_b, skip_c])
        decoded = self.deconv_out(decoded)
        decoded = self.decoder_head(torch.cat([decoded, skip_last], dim=1))
        decoded = self.out_conv(decoded)
        return decoded if self.final_activation is None else self.final_activation(decoded)

    def forward(self, input_, input_shape, original_shape):
        return self.postprocess_masks(self._forward_impl(input_), input_shape, original_shape)
Adapter to contain the UNETR decoder in a single module.
To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
def __init__(self, unetr: torch.nn.Module):
    """Collect the decoder-related members of `unetr` into this module.

    Args:
        unetr: The UNETR model whose decoder parts are re-used (not copied).
    """
    super().__init__()

    # Same registration order as the attribute assignments in the reference implementation.
    for member_name in (
        "base", "out_conv", "deconv_out", "decoder_head", "final_activation", "postprocess_masks",
        "decoder", "deconv1", "deconv2", "deconv3", "deconv4",
    ):
        setattr(self, member_name, getattr(unetr, member_name))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_, input_shape, original_shape):
    """Run the decoder on `input_`, then apply `postprocess_masks` with both shapes."""
    return self.postprocess_masks(self._forward_impl(input_), input_shape, original_shape)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
def get_unetr(
    image_encoder: torch.nn.Module,
    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
    device: Optional[Union[str, torch.device]] = None,
    out_channels: int = 3,
    flexible_load_checkpoint: bool = False,
    final_activation: Optional[str] = "Sigmoid",
) -> torch.nn.Module:
    """Get UNETR model for automatic instance segmentation.

    Args:
        image_encoder: The image encoder of the SAM model.
            This is used as encoder by the UNETR too.
        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
        device: The device. By default, automatically chooses the best available device.
        out_channels: The number of output channels. By default, set to '3'.
        flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found
            in the provided decoder state, or whose shape does not match it. By default, set to 'False'.
        final_activation: The activation applied to the network output. By default uses a Sigmoid activation.

    Returns:
        The UNETR model.

    Raises:
        RuntimeError: If `flexible_load_checkpoint` is False and a decoder parameter is missing
            from `decoder_state` or has a different shape.
    """
    device = util.get_device(device)

    if decoder_state is None:
        use_conv_transpose = False  # By default, we use interpolation for upsampling.
    else:
        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions.
        # NOTE: Explanation to the logic below -
        # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers"
        #   submodules. This naming convention indicates that transposed convolutions are used,
        #   wrapped inside a custom block module.
        # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling.
        use_conv_transpose = any(
            ".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers")
        )

    unetr = UNETR(
        backbone="sam",
        encoder=image_encoder,
        out_channels=out_channels,
        use_sam_stats=True,
        final_activation=final_activation,
        use_skip_connection=False,
        resize_input=True,
        use_conv_transpose=use_conv_transpose,
    )

    if decoder_state is not None:
        unetr_state_dict = unetr.state_dict()
        for k, v in unetr_state_dict.items():
            # Encoder parameters come from `image_encoder`; only the decoder parts are loaded here.
            if k.startswith("encoder"):
                continue

            if flexible_load_checkpoint:
                # Allow reinitialization of params that are missing or mismatched.
                if k in decoder_state:  # First check whether the key is available in the provided decoder state.
                    if v.shape != decoder_state[k].shape:  # Then check if the sizes mismatch.
                        warnings.warn(f"Shape of '{k}' did not match. Hence, we reinitialize it.")
                        unetr_state_dict[k] = v
                    else:
                        unetr_state_dict[k] = decoder_state[k]
                else:  # Otherwise, allow it to initialize it.
                    warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
                    unetr_state_dict[k] = v

            else:
                # Be strict: the parameter must exist AND have the expected shape, as the error message states.
                # (A shape mismatch would otherwise only surface later as a less specific RuntimeError
                # from `load_state_dict`.)
                if k not in decoder_state or decoder_state[k].shape != v.shape:
                    raise RuntimeError(f"The parameters for '{k}' could not be found or has a size mismatch.")
                unetr_state_dict[k] = decoder_state[k]

        unetr.load_state_dict(unetr_state_dict)

    unetr.to(device)
    return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device. By default, automatically chooses the best available device.
- out_channels: The number of output channels. By default, set to '3'.
- flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state. By default, set to 'False'.
- final_activation: The activation applied to the network output. By default uses a Sigmoid activation.
Returns:
The UNETR model.
def get_decoder(
    image_encoder: torch.nn.Module,
    decoder_state: OrderedDict[str, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> DecoderAdapter:
    """Build the decoder that predicts the outputs for automatic instance segmentation.

    A full UNETR is created on top of `image_encoder`, with its decoder initialized from
    `decoder_state`. Only its decoder part is returned, wrapped in a `DecoderAdapter`.

    Args:
        image_encoder: The image encoder of the SAM model.
        decoder_state: State to initialize the weights of the UNETR decoder.
        device: The device. By default, automatically chooses the best available device.

    Returns:
        The decoder for instance segmentation.
    """
    full_unetr = get_unetr(image_encoder=image_encoder, decoder_state=decoder_state, device=device)
    return DecoderAdapter(full_unetr)
Get decoder to predict outputs for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device. By default, automatically chooses the best available device.
Returns:
The decoder for instance segmentation.
def get_predictor_and_decoder(
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    device: Optional[Union[str, torch.device]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> Tuple[SamPredictor, DecoderAdapter]:
    """Load the SAM model (predictor) together with its instance segmentation decoder.

    The checkpoint must hold the state of both the predictor and the decoder
    (the latter under the key 'decoder_state').

    Args:
        model_type: The type of the image encoder used in the SAM model.
        checkpoint_path: Path to the checkpoint from which to load the data.
        device: The device. By default, automatically chooses the best available device.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.

    Returns:
        The SAM predictor.
        The decoder for instance segmentation.

    Raises:
        ValueError: If the loaded state does not contain a decoder state.
    """
    target_device = util.get_device(device)
    predictor, full_state = util.get_sam_model(
        model_type=model_type,
        checkpoint_path=checkpoint_path,
        device=target_device,
        return_state=True,
        peft_kwargs=peft_kwargs,
    )

    if "decoder_state" in full_state:
        decoder = get_decoder(predictor.model.image_encoder, full_state["decoder_state"], target_device)
        return predictor, decoder

    raise ValueError(
        f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
    )
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device. By default, automatically chooses the best available device.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:
The SAM predictor. The decoder for instance segmentation.
class InstanceSegmentationWithDecoder:
    """Generates an instance segmentation without prompts, using a decoder.

    Implements the same interface as `AutomaticMaskGenerator`.

    Use this class as follows:
    ```python
    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
    segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
    ```

    Args:
        predictor: The segment anything predictor.
        decoder: The decoder to predict intermediate representations for instance segmentation.
    """
    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
        self._predictor = predictor
        self._decoder = decoder

        # The decoder outputs.
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None

        # The image index passed to `initialize`; defined here so the attribute always exists.
        self._i = None
        self._is_initialized = False

    @property
    def is_initialized(self):
        """Whether the mask generator has already been initialized.
        """
        return self._is_initialized

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        ndim: int = 2,
    ) -> None:
        """Initialize image embeddings and decoder predictions for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            verbose: Whether to be verbose. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                This enables using this function within a thread worker.
            pbar_update: Callback to update an external progress bar.
            ndim: The dimensionality of the data, used when the image embeddings are computed here
                (i.e. if `image_embeddings` is not given). By default, set to '2'.
        """
        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
        pbar_init(1, "Initialize instance segmentation with decoder")

        if image_embeddings is None:
            image_embeddings = util.precompute_image_embeddings(
                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
            )

        # Get the image embeddings from the predictor.
        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
        embeddings = self._predictor.features
        input_shape = tuple(self._predictor.input_size)
        original_shape = tuple(self._predictor.original_size)

        # Run prediction with the UNETR decoder. squeeze(0) drops the leading (presumably batch) axis;
        # the three remaining channels are foreground, center distances and boundary distances.
        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
        assert output.shape[0] == 3, f"{output.shape}"
        pbar_update(1)
        pbar_close()

        # Set the state.
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]
        self._i = i
        self._is_initialized = True

    def _to_masks(self, segmentation, output_mode):
        if output_mode != "binary_mask":
            raise ValueError(
                f"Output mode {output_mode} is not supported. Choose one of 'instance_segmentation', 'binary_mask'"
            )

        props = regionprops(segmentation)
        ndim = segmentation.ndim
        assert ndim in (2, 3)

        shape = segmentation.shape
        if ndim == 2:
            crop_box = [0, shape[1], 0, shape[0]]
        else:
            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]

        # Go from the skimage bbox format [y0, x0, y1, x1] to the [x0, w, y0, h] layout used here
        # (the same layout as `crop_box` above).
        # NOTE(review): SAM's own AMG output uses [x0, y0, w, h]; confirm callers expect this layout.
        def to_bbox_2d(bbox):
            y0, x0 = bbox[0], bbox[1]
            w = bbox[3] - x0
            h = bbox[2] - y0
            return [x0, w, y0, h]

        # The skimage bbox in 3d is [z0, y0, x0, z1, y1, x1]; the depth is therefore z1 - z0.
        def to_bbox_3d(bbox):
            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
            w = bbox[5] - x0
            h = bbox[4] - y0
            d = bbox[3] - z0
            return [x0, w, y0, h, z0, d]

        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
        masks = [
            {
                "segmentation": segmentation == prop.label,
                "area": prop.area,
                "bbox": to_bbox(prop.bbox),
                "crop_box": crop_box,
                "seg_id": prop.label,
            } for prop in props
        ]
        return masks

    def generate(
        self,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        foreground_smoothing: float = 1.0,
        distance_smoothing: float = 1.6,
        min_size: int = 0,
        output_mode: str = "instance_segmentation",
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        n_threads: Optional[int] = None,
        optimize_memory: bool = False,
        segmentation: Optional[np.ndarray] = None,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate instance segmentation for the currently initialized image.

        Args:
            center_distance_threshold: Center distance predictions below this value will be
                used to find seeds (intersected with thresholded boundary distance predictions).
                By default, set to '0.5'.
            boundary_distance_threshold: Boundary distance predictions below this value will be
                used to find seeds (intersected with thresholded center distance predictions).
                By default, set to '0.5'.
            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
                By default, set to '0.5'.
            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
                checkerboard artifacts in the prediction. By default, set to '1.0'.
            distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
            min_size: Minimal object size in the segmentation result. By default, set to '0'.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                By default, set to 'instance_segmentation'.
            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
                This parameter is independent from the tile shape for computing the embeddings.
                If not given then post-processing will not be parallelized.
            halo: Halo for parallel post-processing. See also `tile_shape`.
            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
            optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
            segmentation: Optional pre-allocated segmentation. Only passed on to the tiled
                (parallel) post-processing; it is not used when `tile_shape` is None.

        Returns:
            The segmentation masks.

        Raises:
            RuntimeError: If the segmenter has not been initialized.
            ValueError: If `tile_shape` is given without `halo`, or for an unsupported `output_mode`.
        """
        if not self.is_initialized:
            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")

        if foreground_smoothing > 0:
            foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads)
        else:
            foreground = self._foreground

        if tile_shape is None:
            segmentation = watershed_from_center_and_boundary_distances(
                center_distances=self._center_distances,
                boundary_distances=self._boundary_distances,
                foreground_map=foreground,
                center_distance_threshold=center_distance_threshold,
                boundary_distance_threshold=boundary_distance_threshold,
                foreground_threshold=foreground_threshold,
                distance_smoothing=distance_smoothing,
                min_size=min_size,
            )
        else:
            if halo is None:
                raise ValueError("You must pass a value for halo if tile_shape is given.")

            # Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling.
            # This is ok in terms of efficiency as GPU tiles are small; shards should still be manageable
            # for the watershed.
            if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None:
                tile_shape = segmentation.shards

            segmentation = _watershed_from_center_and_boundary_distances_parallel(
                center_distances=self._center_distances,
                boundary_distances=self._boundary_distances,
                foreground_map=foreground,
                center_distance_threshold=center_distance_threshold,
                boundary_distance_threshold=boundary_distance_threshold,
                foreground_threshold=foreground_threshold,
                distance_smoothing=distance_smoothing,
                min_size=min_size,
                tile_shape=tile_shape,
                halo=halo,
                n_threads=n_threads,
                verbose=False,
                optimize_memory=optimize_memory,
                segmentation=segmentation,
            )

        if output_mode != "instance_segmentation":
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation

    def get_state(self) -> Dict[str, Any]:
        """Get the initialized state of the instance segmenter.

        Returns:
            Instance segmentation state.

        Raises:
            RuntimeError: If the segmenter has not been initialized.
        """
        if not self.is_initialized:
            raise RuntimeError("The state has not been computed yet. Call initialize first.")

        return {
            "foreground": self._foreground,
            "center_distances": self._center_distances,
            "boundary_distances": self._boundary_distances,
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        """Set the state of the instance segmenter.

        Args:
            state: The instance segmentation state, with the keys returned by `get_state`.
        """
        self._foreground = state["foreground"]
        self._center_distances = state["center_distances"]
        self._boundary_distances = state["boundary_distances"]
        self._is_initialized = True

    def clear_state(self):
        """Clear the state of the instance segmenter.
        """
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None
        self._i = None
        self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as AutomaticMaskGenerator.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
    """Store the predictor and decoder; the decoder outputs are computed by `initialize`.

    Args:
        predictor: The segment anything predictor.
        decoder: The decoder to predict intermediate representations for instance segmentation.
    """
    self._predictor = predictor
    self._decoder = decoder

    # The decoder outputs.
    self._foreground = None
    self._center_distances = None
    self._boundary_distances = None

    # The image index passed to `initialize`; defined here so the attribute always exists.
    self._i = None
    self._is_initialized = False
@property
def is_initialized(self):
    """Report whether the segmenter currently holds decoder predictions.
    """
    initialized_flag = self._is_initialized
    return initialized_flag
Whether the mask generator has already been initialized.
@torch.no_grad()
def initialize(
    self,
    image: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    verbose: bool = False,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    ndim: int = 2,
) -> None:
    """Compute (or load) the image embeddings and run the decoder on them.

    Args:
        image: The input image, volume or timeseries.
        image_embeddings: Optional precomputed image embeddings.
            See `util.precompute_image_embeddings` for details.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        verbose: Whether to be verbose. By default, set to 'False'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            This enables using this function within a thread worker.
        pbar_update: Callback to update an external progress bar.
        ndim: The dimensionality of the data, used when the image embeddings are computed here
            (i.e. if `image_embeddings` is not given). By default, set to '2'.
    """
    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
    pbar_init(1, "Initialize instance segmentation with decoder")

    embedding_source = image_embeddings
    if embedding_source is None:
        embedding_source = util.precompute_image_embeddings(
            predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
        )

    # Load the embeddings into the predictor, then read back the features and the two shapes.
    self._predictor = util.set_precomputed(self._predictor, embedding_source, i=i)
    active_predictor = self._predictor
    features = active_predictor.features
    input_shape = tuple(active_predictor.input_size)
    original_shape = tuple(active_predictor.original_size)

    # Decoder output: squeeze(0) drops the leading (presumably batch) axis; three channels must remain.
    prediction = self._decoder(features, input_shape, original_shape)
    channels = prediction.cpu().numpy().squeeze(0)
    assert channels.shape[0] == 3, f"{channels.shape}"
    pbar_update(1)
    pbar_close()

    # Channel order: foreground, center distances, boundary distances.
    self._foreground, self._center_distances, self._boundary_distances = channels[0], channels[1], channels[2]
    self._i = i
    self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, e.g. to use this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
- ndim: The dimensionality of the data. Only used to compute the image embeddings if `image_embeddings` is not given. By default, 2.
def generate(
    self,
    center_distance_threshold: float = 0.5,
    boundary_distance_threshold: float = 0.5,
    foreground_threshold: float = 0.5,
    foreground_smoothing: float = 1.0,
    distance_smoothing: float = 1.6,
    min_size: int = 0,
    output_mode: str = "instance_segmentation",
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    n_threads: Optional[int] = None,
    optimize_memory: bool = False,
    segmentation: Optional[np.ndarray] = None,
) -> Union[List[Dict[str, Any]], np.ndarray]:
    """Generate instance segmentation for the currently initialized image.

    Args:
        center_distance_threshold: Center distance predictions below this value will be
            used to find seeds (intersected with thresholded boundary distance predictions).
            By default, set to '0.5'.
        boundary_distance_threshold: Boundary distance predictions below this value will be
            used to find seeds (intersected with thresholded center distance predictions).
            By default, set to '0.5'.
        foreground_threshold: Foreground predictions above this value will be used as foreground mask.
            By default, set to '0.5'.
        foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
            checkerboard artifacts in the prediction. By default, set to '1.0'.
        distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
        min_size: Minimal object size in the segmentation result. By default, set to '0'.
        output_mode: The form masks are returned in. Possible values are:
            - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
            - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
            By default, set to 'instance_segmentation'.
        tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
            This parameter is independent from the tile shape for computing the embeddings.
            If not given then post-processing will not be parallelized.
        halo: Halo for parallel post-processing. Required if `tile_shape` is given. See also `tile_shape`.
        n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
        optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
            Only used if `tile_shape` is given.
        segmentation: Optional pre-allocated segmentation. Only used if `tile_shape` is given.
            If it is a zarr array with shards, its shard shape is used as the tile shape.

    Returns:
        The segmentation masks.

    Raises:
        RuntimeError: If the instance segmenter has not been initialized.
        ValueError: If `tile_shape` is given but `halo` is not.
    """
    if not self.is_initialized:
        raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")

    # Validate the tiling arguments before running the (potentially expensive) smoothing.
    if tile_shape is not None and halo is None:
        raise ValueError("You must pass a value for halo if tile_shape is given.")

    if foreground_smoothing > 0:
        foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads)
    else:
        foreground = self._foreground

    if tile_shape is None:
        segmentation = watershed_from_center_and_boundary_distances(
            center_distances=self._center_distances,
            boundary_distances=self._boundary_distances,
            foreground_map=foreground,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
            foreground_threshold=foreground_threshold,
            distance_smoothing=distance_smoothing,
            min_size=min_size,
        )
    else:
        # Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling.
        # This is ok in terms of efficiency as GPU tiles are small; shards should still be manageable
        # for the watershed.
        if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None:
            tile_shape = segmentation.shards

        segmentation = _watershed_from_center_and_boundary_distances_parallel(
            center_distances=self._center_distances,
            boundary_distances=self._boundary_distances,
            foreground_map=foreground,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
            foreground_threshold=foreground_threshold,
            distance_smoothing=distance_smoothing,
            min_size=min_size,
            tile_shape=tile_shape,
            halo=halo,
            n_threads=n_threads,
            verbose=False,
            optimize_memory=optimize_memory,
            segmentation=segmentation,
        )

    if output_mode != "instance_segmentation":
        segmentation = self._to_masks(segmentation, output_mode)
    return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions). By default, set to '0.5'.
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions). By default, set to '0.5'.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask. By default, set to '0.5'.
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction. By default, set to '1.0'.
- distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
- min_size: Minimal object size in the segmentation result. By default, set to '0'.
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
- tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
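- halo: Halo for parallel post-processing. Required if `tile_shape` is given. See also `tile_shape`.
- n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
- optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
- segmentation: Optional pre-allocated segmentation. Only used if `tile_shape` is given; if it is a zarr array with shards, its shard shape is used as the tile shape.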
Returns:
The segmentation masks.
def get_state(self) -> Dict[str, Any]:
    """Return the decoder predictions of the instance segmenter.

    Returns:
        Dictionary with the "foreground", "center_distances" and "boundary_distances" predictions,
        in a form that can be passed to `set_state`.

    Raises:
        RuntimeError: If the instance segmenter has not been initialized.
    """
    if self.is_initialized:
        return dict(
            foreground=self._foreground,
            center_distances=self._center_distances,
            boundary_distances=self._boundary_distances,
        )
    raise RuntimeError("The state has not been computed yet. Call initialize first.")
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
def set_state(self, state: Dict[str, Any]) -> None:
    """Restore the decoder predictions of the instance segmenter and mark it as initialized.

    Args:
        state: Dictionary with the "foreground", "center_distances" and "boundary_distances"
            predictions, e.g. as returned by `get_state`.
    """
    # Assign the entries one after the other (same order as the keys below) to private attributes.
    for key in ("foreground", "center_distances", "boundary_distances"):
        setattr(self, f"_{key}", state[key])
    self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state
class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
    """

    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
        batched_embeddings = torch.cat(batched_embeddings)
        output = self._decoder._forward_impl(batched_embeddings)

        batched_output = []
        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
            batched_output.append(x.cpu().numpy())
        return batched_output

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
        mask: Optional[np.typing.ArrayLike] = None,
    ) -> None:
        """Initialize image embeddings and decoder predictions for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: Shape of the tiles for precomputing image embeddings.
            halo: Overlap of the tiles for tiled precomputation of image embeddings.
            verbose: Whether to be verbose. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle a napari progress bar in another thread,
                e.g. to use this function within a thread worker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
            mask: An optional mask to define areas that are ignored in the segmentation.
                Tiles outside of the mask are not predicted; their predictions are set to zero.
        """
        # NOTE(review): assumes the two spatial axes lead the image shape (e.g. YX or YXC);
        # TODO confirm this is correct for volumes / timeseries indexed via `i`.
        original_size = image.shape[:2]
        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
            self._predictor, image, image_embeddings, tile_shape, halo,
            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
        )
        tiling = Blocking([0, 0], original_size, tile_shape)

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

        # Predictions of tiles that are not processed (outside of the mask) stay zero.
        foreground = np.zeros(original_size, dtype="float32")
        center_distances = np.zeros(original_size, dtype="float32")
        boundary_distances = np.zeros(original_size, dtype="float32")

        msg = "Initialize tiled instance segmentation with decoder"
        if tiles_in_mask is None:
            n_tiles = tiling.number_of_blocks
            all_tile_ids = list(range(n_tiles))
        else:
            n_tiles = len(tiles_in_mask)
            all_tile_ids = tiles_in_mask
            msg += " and mask"

        # np.array_split rejects zero sections, which happens if the mask excludes all tiles.
        # In this case there is nothing to predict and all predictions stay zero.
        n_batches = int(np.ceil(n_tiles / batch_size))
        tile_ids_for_batches = np.array_split(all_tile_ids, n_batches) if n_batches > 0 else []

        pbar_init(n_tiles, msg)
        # Always close the progress bar, also if the decoder prediction fails.
        try:
            for tile_ids in tile_ids_for_batches:
                batched_embeddings, input_shapes, original_shapes = [], [], []
                for tile_id in tile_ids:
                    # Get the image embeddings from the predictor for this tile.
                    self._predictor = util.set_precomputed(
                        self._predictor, self._image_embeddings, i=i, tile_id=tile_id
                    )

                    batched_embeddings.append(self._predictor.features)
                    input_shapes.append(tuple(self._predictor.input_size))
                    original_shapes.append(tuple(self._predictor.original_size))

                batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)

                for output_id, tile_id in enumerate(tile_ids):
                    output = batched_output[output_id]
                    assert output.shape[0] == 3

                    # Set the predictions in the output for this tile (only the inner block, without halo).
                    block = tiling.get_block_with_halo(tile_id, halo=list(halo))
                    local_bb = tuple(
                        slice(beg, end) for beg, end in zip(block.inner_block_local.begin, block.inner_block_local.end)
                    )
                    inner_bb = tuple(slice(beg, end) for beg, end in zip(block.inner_block.begin, block.inner_block.end))

                    foreground[inner_bb] = output[0][local_bb]
                    center_distances[inner_bb] = output[1][local_bb]
                    boundary_distances[inner_bb] = output[2][local_bb]
                    pbar_update(1)
        finally:
            pbar_close()

        # Set the state.
        self._i = i
        self._foreground = foreground
        self._center_distances = center_distances
        self._boundary_distances = boundary_distances
        self._is_initialized = True
Same as InstanceSegmentationWithDecoder but for tiled image embeddings.
@torch.no_grad()
def initialize(
    self,
    image: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = False,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
    batch_size: int = 1,
    mask: Optional[np.typing.ArrayLike] = None,
) -> None:
    """Initialize image embeddings and decoder predictions for an image.

    Args:
        image: The input image, volume or timeseries.
        image_embeddings: Optional precomputed image embeddings.
            See `util.precompute_image_embeddings` for details.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_shape: Shape of the tiles for precomputing image embeddings.
        halo: Overlap of the tiles for tiled precomputation of image embeddings.
        verbose: Whether to be verbose. By default, set to 'False'.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle a napari progress bar in another thread,
            e.g. to use this function within a thread worker.
        pbar_update: Callback to update an external progress bar.
        batch_size: The batch size for image embedding computation and segmentation decoder prediction.
        mask: An optional mask to define areas that are ignored in the segmentation.
            Tiles outside of the mask are not predicted; their predictions are set to zero.
    """
    # NOTE(review): assumes the two spatial axes lead the image shape (e.g. YX or YXC);
    # TODO confirm this is correct for volumes / timeseries indexed via `i`.
    original_size = image.shape[:2]
    self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
        self._predictor, image, image_embeddings, tile_shape, halo,
        verbose=verbose, batch_size=batch_size, mask=mask, i=i,
    )
    tiling = Blocking([0, 0], original_size, tile_shape)

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

    # Predictions of tiles that are not processed (outside of the mask) stay zero.
    foreground = np.zeros(original_size, dtype="float32")
    center_distances = np.zeros(original_size, dtype="float32")
    boundary_distances = np.zeros(original_size, dtype="float32")

    msg = "Initialize tiled instance segmentation with decoder"
    if tiles_in_mask is None:
        n_tiles = tiling.number_of_blocks
        all_tile_ids = list(range(n_tiles))
    else:
        n_tiles = len(tiles_in_mask)
        all_tile_ids = tiles_in_mask
        msg += " and mask"

    # np.array_split rejects zero sections, which happens if the mask excludes all tiles.
    # In this case there is nothing to predict and all predictions stay zero.
    n_batches = int(np.ceil(n_tiles / batch_size))
    tile_ids_for_batches = np.array_split(all_tile_ids, n_batches) if n_batches > 0 else []

    pbar_init(n_tiles, msg)
    # Always close the progress bar, also if the decoder prediction fails.
    try:
        for tile_ids in tile_ids_for_batches:
            batched_embeddings, input_shapes, original_shapes = [], [], []
            for tile_id in tile_ids:
                # Get the image embeddings from the predictor for this tile.
                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)

                batched_embeddings.append(self._predictor.features)
                input_shapes.append(tuple(self._predictor.input_size))
                original_shapes.append(tuple(self._predictor.original_size))

            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)

            for output_id, tile_id in enumerate(tile_ids):
                output = batched_output[output_id]
                assert output.shape[0] == 3

                # Set the predictions in the output for this tile (only the inner block, without halo).
                block = tiling.get_block_with_halo(tile_id, halo=list(halo))
                local_bb = tuple(
                    slice(beg, end) for beg, end in zip(block.inner_block_local.begin, block.inner_block_local.end)
                )
                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.inner_block.begin, block.inner_block.end))

                foreground[inner_bb] = output[0][local_bb]
                center_distances[inner_bb] = output[1][local_bb]
                boundary_distances[inner_bb] = output[2][local_bb]
                pbar_update(1)
    finally:
        pbar_close()

    # Set the state.
    self._i = i
    self._foreground = foreground
    self._center_distances = center_distances
    self._boundary_distances = boundary_distances
    self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: Shape of the tiles for precomputing image embeddings.
- halo: Overlap of the tiles for tiled precomputation of image embeddings.
- verbose: Whether to be verbose. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, e.g. to use this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding computation and segmentation decoder prediction.
- mask: An optional mask to define areas that are ignored in the segmentation. Tiles outside of the mask are not predicted; their predictions are set to zero.
class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
    """Generates an instance segmentation automatically, using automatically generated prompts from a decoder.

    This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator`.

    Args:
        predictor: The segment anything predictor.
        decoder: The decoder used to derive prompts for automatic instance segmentation.
    """
    def generate(
        self,
        min_size: int = 25,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        multimasking: bool = False,
        batch_size: int = 32,
        nms_threshold: float = 0.9,
        intersection_over_min: bool = False,
        output_mode: str = "instance_segmentation",
        mask_threshold: Optional[Union[float, str]] = None,
        refine_with_box_prompts: bool = False,
        prompt_function: Optional[callable] = None,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate the instance segmentation for the currently initialized image.

        By default, the instance segmentation is generated by deriving prompts from the foreground and
        distance predictions of the segmentation decoder by thresholding these predictions,
        intersecting them, computing connected components, and then using the components'
        centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.

        Args:
            min_size: Minimal object size in the segmentation result. By default, set to '25'.
            center_distance_threshold: The threshold for the center distance predictions. By default, set to '0.5'.
            boundary_distance_threshold: The threshold for the boundary distance predictions.
                By default, set to '0.5'.
            foreground_threshold: The threshold for the foreground predictions. By default, set to '0.5'.
            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
            batch_size: The batch size for parallelizing the prediction based on prompts.
            nms_threshold: The threshold for non-maximum suppression (NMS).
            intersection_over_min: Whether to use the minimum area of the two objects or the
                intersection over union area (default) in NMS.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                By default, set to 'instance_segmentation'.
            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
                derived from the segmentations after point prompts.
            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
                If given, the default prompt derivation procedure is not used. Must have the following signature:
                ```
                def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
                ```
                where `foreground`, `center_distances`, and `boundary_distances` are the respective
                predictions from the segmentation decoder. The three thresholds are passed as keyword arguments.
                It must return a dictionary containing either point, box, or mask prompts in a format
                compatible with `micro_sam.inference.batched_inference`, or None if no prompts were derived.

        Returns:
            The instance segmentation masks. If no prompts were derived, this is an all-zero segmentation
            (or an empty list if `output_mode` is not 'instance_segmentation').

        Raises:
            RuntimeError: If the prompt generator has not been initialized.
        """
        if not self.is_initialized:
            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
        foreground, center_distances, boundary_distances =\
            self._foreground, self._center_distances, self._boundary_distances

        def _predict(prompt_kwargs):
            # Turn the prompts into (not yet merged) masks with the SAM predictor.
            # `_i` may be unset, e.g. if the state was restored via `set_state`.
            return batched_inference(
                self._predictor,
                image=None,
                batch_size=batch_size,
                return_instance_segmentation=False,
                multimasking=multimasking,
                mask_threshold=mask_threshold,
                i=getattr(self, "_i", None),
                **prompt_kwargs,
            )

        # 1.) Derive prompts from the decoder predictions.
        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
        prompts = prompt_function(
            foreground=foreground,
            center_distances=center_distances,
            boundary_distances=boundary_distances,
            foreground_threshold=foreground_threshold,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
        )

        # 2.) Apply the predictor to the prompts.
        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
            return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
        predictions = _predict(prompts)

        # 3.) Refine the segmentation with box prompts.
        if refine_with_box_prompts:
            box_extension = 0.01  # expose as hyperparam?
            predictions = _predict(_derive_box_prompts(predictions, box_extension))

        # 4.) Apply non-max suppression to the masks.
        segmentation = util.apply_nms(
            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
        )
        if output_mode != "instance_segmentation":
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation
Generates an instance segmentation automatically, using automatically generated prompts from a decoder.
This class is used in the same way as InstanceSegmentationWithDecoder and AutomaticMaskGenerator.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder used to derive prompts for automatic instance segmentation.
def generate(
    self,
    min_size: int = 25,
    center_distance_threshold: float = 0.5,
    boundary_distance_threshold: float = 0.5,
    foreground_threshold: float = 0.5,
    multimasking: bool = False,
    batch_size: int = 32,
    nms_threshold: float = 0.9,
    intersection_over_min: bool = False,
    output_mode: str = "instance_segmentation",
    mask_threshold: Optional[Union[float, str]] = None,
    refine_with_box_prompts: bool = False,
    prompt_function: Optional[callable] = None,
) -> Union[List[Dict[str, Any]], np.ndarray]:
    """Generate the instance segmentation for the currently initialized image.

    By default, the instance segmentation is generated by deriving prompts from the foreground and
    distance predictions of the segmentation decoder by thresholding these predictions,
    intersecting them, computing connected components, and then using the components'
    centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.

    Args:
        min_size: Minimal object size in the segmentation result. By default, set to '25'.
        center_distance_threshold: The threshold for the center distance predictions. By default, set to '0.5'.
        boundary_distance_threshold: The threshold for the boundary distance predictions. By default, set to '0.5'.
        foreground_threshold: The threshold for the foreground predictions. By default, set to '0.5'.
        multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
        batch_size: The batch size for parallelizing the prediction based on prompts.
        nms_threshold: The threshold for non-maximum suppression (NMS).
        intersection_over_min: Whether to use the minimum area of the two objects or the
            intersection over union area (default) in NMS.
        output_mode: The form masks are returned in. Possible values are:
            - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
            - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
            By default, set to 'instance_segmentation'.
        mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
        refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
            derived from the segmentations after point prompts.
        prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
            If given, the default prompt derivation procedure is not used. Must have the following signature:
            ```
            def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
            ```
            where `foreground`, `center_distances`, and `boundary_distances` are the respective
            predictions from the segmentation decoder. The three thresholds are passed as keyword arguments.
            It must return a dictionary containing either point, box, or mask prompts in a format
            compatible with `micro_sam.inference.batched_inference`, or None if no prompts were derived.

    Returns:
        The instance segmentation masks. If no prompts were derived, this is an all-zero segmentation
        (or an empty list if `output_mode` is not 'instance_segmentation').

    Raises:
        RuntimeError: If the prompt generator has not been initialized.
    """
    if not self.is_initialized:
        raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
    foreground, center_distances, boundary_distances =\
        self._foreground, self._center_distances, self._boundary_distances

    def _predict(prompt_kwargs):
        # Turn the prompts into (not yet merged) masks with the SAM predictor.
        # `_i` may be unset, e.g. if the state was restored via `set_state`.
        return batched_inference(
            self._predictor,
            image=None,
            batch_size=batch_size,
            return_instance_segmentation=False,
            multimasking=multimasking,
            mask_threshold=mask_threshold,
            i=getattr(self, "_i", None),
            **prompt_kwargs,
        )

    # 1.) Derive prompts from the decoder predictions.
    prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
    prompts = prompt_function(
        foreground=foreground,
        center_distances=center_distances,
        boundary_distances=boundary_distances,
        foreground_threshold=foreground_threshold,
        center_distance_threshold=center_distance_threshold,
        boundary_distance_threshold=boundary_distance_threshold,
    )

    # 2.) Apply the predictor to the prompts.
    if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
        return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
    predictions = _predict(prompts)

    # 3.) Refine the segmentation with box prompts.
    if refine_with_box_prompts:
        box_extension = 0.01  # expose as hyperparam?
        predictions = _predict(_derive_box_prompts(predictions, box_extension))

    # 4.) Apply non-max suppression to the masks.
    segmentation = util.apply_nms(
        predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
    )
    if output_mode != "instance_segmentation":
        segmentation = self._to_masks(segmentation, output_mode)
    return segmentation
Generate the instance segmentation for the currently initialized image.
By default, the instance segmentation is generated by deriving prompts from the foreground and distance predictions of the segmentation decoder by thresholding these predictions, intersecting them, computing connected components, and then using the components' centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
Arguments:
- min_size: Minimal object size in the segmentation result. By default, set to '25'.
- center_distance_threshold: The threshold for the center distance predictions.
- boundary_distance_threshold: The threshold for the boundary distance predictions.
- foreground_threshold: The threshold for the foreground predictions. By default, set to '0.5'.
- multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
- batch_size: The batch size for parallelizing the prediction based on prompts.
- nms_threshold: The threshold for non-maximum suppression (NMS).
- intersection_over_min: Whether to use the minimum area of the two objects or the intersection over union area (default) in NMS.
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
- mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
- refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts.
- prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. If given, the default prompt derivation procedure is not used. Must have the following signature:
`def prompt_function(foreground, center_distances, boundary_distances, **kwargs)`
where `foreground`, `center_distances`, and `boundary_distances` are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`, or None if no prompts were derived.
Returns:
The instance segmentation masks.
class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
    """Same as `AutomaticPromptGenerator` but for tiled image embeddings.
    """
    def generate(
        self,
        min_size: int = 25,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        multimasking: bool = False,
        batch_size: int = 32,
        nms_threshold: float = 0.9,
        intersection_over_min: bool = False,
        output_mode: str = "instance_segmentation",
        mask_threshold: Optional[Union[float, str]] = None,
        refine_with_box_prompts: bool = False,
        prompt_function: Optional[callable] = None,
        optimize_memory: bool = False,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate tiling-based instance segmentation for the currently initialized image.

        Args:
            min_size: Minimal object size in the segmentation result. By default, set to '25'.
            center_distance_threshold: The threshold for the center distance predictions.
            boundary_distance_threshold: The threshold for the boundary distance predictions.
            foreground_threshold: The threshold for the foreground predictions. It is passed to
                `prompt_function` as the `foreground_threshold` keyword argument.
            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
            batch_size: The batch size for parallelizing the prediction based on prompts.
            nms_threshold: The threshold for non-maximum suppression (NMS).
            intersection_over_min: Whether to use the minimum area of the two objects or the
                intersection over union area (default) in NMS.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                By default, set to 'instance_segmentation'.
            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
                Note: this value is currently not used by the tiled generator.
            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
                derived from the segmentations after point prompts. Currently not supported for tiled
                segmentation; passing `True` raises a `NotImplementedError`.
            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
                If given, the default prompt derivation procedure is not used. Must have the following signature:
                ```
                def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
                ```
                where `foreground`, `center_distances`, and `boundary_distances` are the respective
                predictions from the segmentation decoder. It must return a dictionary containing
                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`,
                or None if no prompts could be derived.
            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
                segmentation results immediately with NMS, rather than running an NMS for all results.
                This may lead to a slightly different segmentation result and is only compatible with
                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.

        Returns:
            The instance segmentation as a single array if `output_mode` is 'instance_segmentation',
            otherwise the masks in the format requested by `output_mode`. If no prompts could be derived,
            an all-zero segmentation (or an empty list) is returned.

        Raises:
            RuntimeError: If the generator has not been initialized.
            ValueError: If `optimize_memory` is combined with unsupported settings.
            NotImplementedError: If `refine_with_box_prompts` is set.
        """
        if not self.is_initialized:
            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
            raise ValueError(
                "optimize_memory=True is only supported with output_mode='instance_segmentation' "
                "and refine_with_box_prompts=False."
            )
        # Box prompt refinement is not implemented for tiled embeddings yet (TODO).
        # Fail before the (expensive) prompt derivation and tiled inference instead of after it.
        if refine_with_box_prompts:
            raise NotImplementedError("refine_with_box_prompts is not supported for tiled segmentation yet.")

        foreground = self._foreground
        center_distances, boundary_distances = self._center_distances, self._boundary_distances

        # 1.) Derive prompts from the decoder predictions.
        if prompt_function is None:
            prompt_function = _derive_point_prompts
        prompts = prompt_function(
            foreground,
            center_distances,
            boundary_distances,
            foreground_threshold=foreground_threshold,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
        )

        # 2.) Apply the predictor to the prompts.
        shape = foreground.shape
        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
            return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []

        if optimize_memory:
            # The NMS settings are forwarded to the tiled inference, which merges the results directly.
            # Build a new dict so that a dict returned (and possibly reused) by a custom
            # prompt_function is not mutated.
            prompts = {
                **prompts,
                "min_size": min_size,
                "nms_thresh": nms_threshold,
                "intersection_over_min": intersection_over_min,
            }

        # NOTE(review): mask_threshold is not forwarded here, unlike the non-tiled AutomaticPromptGenerator,
        # which passes it to batched_inference. Confirm whether batched_tiled_inference supports it.
        predictions = batched_tiled_inference(
            self._predictor,
            image=None,
            batch_size=batch_size,
            image_embeddings=self._image_embeddings,
            return_instance_segmentation=False,
            multimasking=multimasking,
            optimize_memory=optimize_memory,
            i=getattr(self, "_i", None),
            **prompts,
        )
        # Optimize memory directly returns an instance segmentation and does not
        # allow for any further refinements.
        if optimize_memory:
            return predictions

        # 3.) Apply non-max suppression to the masks.
        segmentation = util.apply_nms(
            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
            intersection_over_min=intersection_over_min,
        )
        if output_mode != "instance_segmentation":
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation

    # Set state and get state are not implemented yet, as this generator relies on having the image embeddings
    # in the state. However, they should not be serialized here and we have to address this a bit differently.
    def get_state(self):
        """@private
        """
        raise NotImplementedError

    def set_state(self, state):
        """@private
        """
        raise NotImplementedError
Same as AutomaticPromptGenerator but for tiled image embeddings.
def generate(
    self,
    min_size: int = 25,
    center_distance_threshold: float = 0.5,
    boundary_distance_threshold: float = 0.5,
    foreground_threshold: float = 0.5,
    multimasking: bool = False,
    batch_size: int = 32,
    nms_threshold: float = 0.9,
    intersection_over_min: bool = False,
    output_mode: str = "instance_segmentation",
    mask_threshold: Optional[Union[float, str]] = None,
    refine_with_box_prompts: bool = False,
    prompt_function: Optional[callable] = None,
    optimize_memory: bool = False,
) -> Union[List[Dict[str, Any]], np.ndarray]:
    """Generate tiling-based instance segmentation for the currently initialized image.

    Args:
        min_size: Minimal object size in the segmentation result. By default, set to '25'.
        center_distance_threshold: The threshold for the center distance predictions.
        boundary_distance_threshold: The threshold for the boundary distance predictions.
        foreground_threshold: The threshold for the foreground predictions. It is passed to
            `prompt_function` as the `foreground_threshold` keyword argument.
        multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
        batch_size: The batch size for parallelizing the prediction based on prompts.
        nms_threshold: The threshold for non-maximum suppression (NMS).
        intersection_over_min: Whether to use the minimum area of the two objects or the
            intersection over union area (default) in NMS.
        output_mode: The form masks are returned in. Possible values are:
            - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
            - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
            By default, set to 'instance_segmentation'.
        mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
            Note: this value is currently not used by the tiled generator.
        refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
            derived from the segmentations after point prompts. Currently not supported for tiled
            segmentation; passing `True` raises a `NotImplementedError`.
        prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
            If given, the default prompt derivation procedure is not used. Must have the following signature:
            ```
            def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
            ```
            where `foreground`, `center_distances`, and `boundary_distances` are the respective
            predictions from the segmentation decoder. It must return a dictionary containing
            either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`,
            or None if no prompts could be derived.
        optimize_memory: Whether to optimize the memory consumption by merging the per-slice
            segmentation results immediately with NMS, rather than running an NMS for all results.
            This may lead to a slightly different segmentation result and is only compatible with
            `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.

    Returns:
        The instance segmentation as a single array if `output_mode` is 'instance_segmentation',
        otherwise the masks in the format requested by `output_mode`. If no prompts could be derived,
        an all-zero segmentation (or an empty list) is returned.

    Raises:
        RuntimeError: If the generator has not been initialized.
        ValueError: If `optimize_memory` is combined with unsupported settings.
        NotImplementedError: If `refine_with_box_prompts` is set.
    """
    if not self.is_initialized:
        raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
    if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
        raise ValueError(
            "optimize_memory=True is only supported with output_mode='instance_segmentation' "
            "and refine_with_box_prompts=False."
        )
    # Box prompt refinement is not implemented for tiled embeddings yet (TODO).
    # Fail before the (expensive) prompt derivation and tiled inference instead of after it.
    if refine_with_box_prompts:
        raise NotImplementedError("refine_with_box_prompts is not supported for tiled segmentation yet.")

    foreground = self._foreground
    center_distances, boundary_distances = self._center_distances, self._boundary_distances

    # 1.) Derive prompts from the decoder predictions.
    if prompt_function is None:
        prompt_function = _derive_point_prompts
    prompts = prompt_function(
        foreground,
        center_distances,
        boundary_distances,
        foreground_threshold=foreground_threshold,
        center_distance_threshold=center_distance_threshold,
        boundary_distance_threshold=boundary_distance_threshold,
    )

    # 2.) Apply the predictor to the prompts.
    shape = foreground.shape
    if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
        return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []

    if optimize_memory:
        # The NMS settings are forwarded to the tiled inference, which merges the results directly.
        # Build a new dict so that a dict returned (and possibly reused) by a custom
        # prompt_function is not mutated.
        prompts = {
            **prompts,
            "min_size": min_size,
            "nms_thresh": nms_threshold,
            "intersection_over_min": intersection_over_min,
        }

    # NOTE(review): mask_threshold is not forwarded here, unlike the non-tiled AutomaticPromptGenerator,
    # which passes it to batched_inference. Confirm whether batched_tiled_inference supports it.
    predictions = batched_tiled_inference(
        self._predictor,
        image=None,
        batch_size=batch_size,
        image_embeddings=self._image_embeddings,
        return_instance_segmentation=False,
        multimasking=multimasking,
        optimize_memory=optimize_memory,
        i=getattr(self, "_i", None),
        **prompts,
    )
    # Optimize memory directly returns an instance segmentation and does not
    # allow for any further refinements.
    if optimize_memory:
        return predictions

    # 3.) Apply non-max suppression to the masks.
    segmentation = util.apply_nms(
        predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
        intersection_over_min=intersection_over_min,
    )
    if output_mode != "instance_segmentation":
        segmentation = self._to_masks(segmentation, output_mode)
    return segmentation
Generate tiling-based instance segmentation for the currently initialized image.
Arguments:
- min_size: Minimal object size in the segmentation result. By default, set to '25'.
- center_distance_threshold: The threshold for the center distance predictions.
- boundary_distance_threshold: The threshold for the boundary distance predictions.
- foreground_threshold: The threshold for the foreground predictions.
- multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
- batch_size: The batch size for parallelizing the prediction based on prompts.
- nms_threshold: The threshold for non-maximum suppression (NMS).
- intersection_over_min: Whether to use the minimum area of the two objects or the intersection over union area (default) in NMS.
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
- mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
- refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
- prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. If given, the default prompt derivation procedure is not used. Must have the following signature:
def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.
- optimize_memory: Whether to optimize the memory consumption by merging the per-slice
segmentation results immediately with NMS, rather than running an NMS for all results.
This may lead to a slightly different segmentation result and is only compatible with
refine_with_box_prompts=False and output_mode="instance_segmentation".
Returns:
The instance segmentation masks.
def get_instance_segmentation_generator(
    predictor: SamPredictor,
    is_tiled: bool,
    decoder: Optional[torch.nn.Module] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    **kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
    """Get the automatic mask generator.

    Args:
        predictor: The segment anything predictor.
        is_tiled: Whether tiled embeddings are used.
        decoder: Decoder to predict instance segmentation. Required for the 'ais' and 'apg' modes.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg' (case-insensitive).
            By default, the mode given by `DEFAULT_SEGMENTATION_MODE_WITH_DECODER` is used if a decoder
            is passed, otherwise 'amg' is used.
        kwargs: The keyword arguments of the segmentation generator class.

    Returns:
        The segmentation generator instance.

    Raises:
        ValueError: If `segmentation_mode` is invalid, or if it requires a decoder and none was passed.
    """
    # NOTE: this docstring is a plain string on purpose. An f-string in this position is not
    # treated as a docstring by Python, which would leave __doc__ as None.

    # Choose the segmentation mode default depending on whether we have a decoder.
    if segmentation_mode is None:
        segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER
    mode = segmentation_mode.lower()

    if mode == "amg":
        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
        return segmenter_class(predictor, **kwargs)

    if mode == "ais":
        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
    elif mode == "apg":
        segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator
    else:
        raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.")

    # Explicit check instead of an assert, so that the validation is not stripped when running with -O.
    if decoder is None:
        raise ValueError(f"A decoder is required for segmentation_mode '{segmentation_mode}'.")
    return segmenter_class(predictor, decoder, **kwargs)