micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
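For reference, here is a minimal usage sketch of the two main workflows described below: point-grid based segmentation with `AutomaticMaskGenerator` and decoder based segmentation with `InstanceSegmentationWithDecoder`. The model names and the random input image are placeholders; any SAM model works for the first workflow, while the second requires a micro_sam checkpoint that also contains a decoder state.

```python
import numpy as np
import micro_sam.util as util
from micro_sam.instance_segmentation import (
    AutomaticMaskGenerator, InstanceSegmentationWithDecoder, get_predictor_and_decoder,
)

image = np.random.rand(512, 512).astype("float32")  # stand-in for a real 2D image

# Point-grid based segmentation (AMG): works with any SAM model.
predictor = util.get_sam_model(model_type="vit_b")  # model name is an example
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # expensive: computes embeddings and candidate masks
instances = amg.generate(pred_iou_thresh=0.88)  # cheap: only filters and merges masks

# Decoder based segmentation (AIS): requires a checkpoint with a decoder state.
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")  # model name is an example
ais = InstanceSegmentationWithDecoder(predictor, decoder)
ais.initialize(image)
instances = ais.generate(center_distance_threshold=0.5)
```

With the default `output_mode="instance_segmentation"` both `generate` calls return a single label image; the other output modes return a list of per-mask dictionaries instead.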
1"""Automated instance segmentation functionality. 2The classes implemented here extend the automatic instance segmentation from Segment Anything: 3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 4""" 5 6import os 7import warnings 8from abc import ABC 9from copy import deepcopy 10from collections import OrderedDict 11from typing import Any, Dict, List, Optional, Tuple, Union 12 13import vigra 14import numpy as np 15from skimage.measure import regionprops 16from skimage.segmentation import find_boundaries 17 18import torch 19from torchvision.ops.boxes import batched_nms, box_area 20 21from torch_em.model import UNETR 22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 23 24import elf.parallel as parallel_impl 25from elf.parallel.filters import apply_filter 26 27from nifty.tools import blocking 28 29import segment_anything.utils.amg as amg_utils 30from segment_anything.predictor import SamPredictor 31 32from . import util 33from .inference import batched_inference, batched_tiled_inference 34from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 35 36# We may change this to 'apg' in version 1.8. 37DEFAULT_SEGMENTATION_MODE_WITH_DECODER = "ais" 38 39# 40# Utility Functionality 41# 42 43 44class _FakeInput: 45 def __init__(self, shape): 46 self.shape = shape 47 48 def __getitem__(self, index): 49 block_shape = tuple(ind.stop - ind.start for ind in index) 50 return np.zeros(block_shape, dtype="float32") 51 52 53# 54# Classes for automatic instance segmentation 55# 56 57 58class AMGBase(ABC): 59 """Base class for the automatic mask generators. 60 """ 61 def __init__(self): 62 # the state that has to be computed by the 'initialize' method of the child classes 63 self._is_initialized = False 64 self._crop_list = None 65 self._crop_boxes = None 66 self._original_size = None 67 68 @property 69 def is_initialized(self): 70 """Whether the mask generator has already been initialized. 71 """ 72 return self._is_initialized 73 74 @property 75 def crop_list(self): 76 """The list of mask data after initialization. 77 """ 78 return self._crop_list 79 80 @property 81 def crop_boxes(self): 82 """The list of crop boxes. 83 """ 84 return self._crop_boxes 85 86 @property 87 def original_size(self): 88 """The original image size. 89 """ 90 return self._original_size 91 92 def _postprocess_batch( 93 self, 94 data, 95 crop_box, 96 original_size, 97 pred_iou_thresh, 98 stability_score_thresh, 99 box_nms_thresh, 100 ): 101 orig_h, orig_w = original_size 102 103 # filter by predicted IoU 104 if pred_iou_thresh > 0.0: 105 keep_mask = data["iou_preds"] > pred_iou_thresh 106 data.filter(keep_mask) 107 108 # filter by stability score 109 if stability_score_thresh > 0.0: 110 keep_mask = data["stability_score"] >= stability_score_thresh 111 data.filter(keep_mask) 112 113 # filter boxes that touch crop boundaries 114 keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 115 if not torch.all(keep_mask): 116 data.filter(keep_mask) 117 118 # remove duplicates within this crop. 
119 keep_by_nms = batched_nms( 120 data["boxes"].float(), 121 data["iou_preds"], 122 torch.zeros_like(data["boxes"][:, 0]), # categories 123 iou_threshold=box_nms_thresh, 124 ) 125 data.filter(keep_by_nms) 126 127 # return to the original image frame 128 data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box) 129 data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 130 # the data from embedding based segmentation doesn't have the points 131 # so we skip if the corresponding key can't be found 132 try: 133 data["points"] = amg_utils.uncrop_points(data["points"], crop_box) 134 except KeyError: 135 pass 136 137 return data 138 139 def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): 140 141 if len(mask_data["rles"]) == 0: 142 return mask_data 143 144 # filter small disconnected regions and holes 145 new_masks = [] 146 scores = [] 147 for rle in mask_data["rles"]: 148 mask = amg_utils.rle_to_mask(rle) 149 150 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes") 151 unchanged = not changed 152 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands") 153 unchanged = unchanged and not changed 154 155 new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0)) 156 # give score=0 to changed masks and score=1 to unchanged masks 157 # so NMS will prefer ones that didn't need postprocessing 158 scores.append(float(unchanged)) 159 160 # recalculate boxes and remove any new duplicates 161 masks = torch.cat(new_masks, dim=0) 162 boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. 163 keep_by_nms = batched_nms( 164 boxes.float(), 165 torch.as_tensor(scores, dtype=torch.float), 166 torch.zeros_like(boxes[:, 0]), # categories 167 iou_threshold=nms_thresh, 168 ) 169 170 # only recalculate RLEs for masks that have changed 171 for i_mask in keep_by_nms: 172 if scores[i_mask] == 0.0: 173 mask_torch = masks[i_mask].unsqueeze(0) 174 # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] 175 mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 176 mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 177 mask_data.filter(keep_by_nms) 178 179 return mask_data 180 181 def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode): 182 # filter small disconnected regions and holes in masks 183 if min_mask_region_area > 0: 184 mask_data = self._postprocess_small_regions( 185 mask_data, 186 min_mask_region_area, 187 max(box_nms_thresh, crop_nms_thresh), 188 ) 189 190 # encode masks 191 if output_mode == "coco_rle": 192 mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]] 193 # We require the binary mask output for creating the instance-seg output. 
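# (For 'instance_segmentation' the binary masks produced here are merged into a single label image
# later on, via util.mask_data_to_segmentation in the generate method.)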
194 elif output_mode in ("binary_mask", "instance_segmentation"): 195 mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]] 196 elif output_mode == "rle": 197 mask_data["segmentations"] = mask_data["rles"] 198 else: 199 raise ValueError(f"Invalid output mode {output_mode}.") 200 201 # write mask records 202 curr_anns = [] 203 for idx in range(len(mask_data["segmentations"])): 204 ann = { 205 "segmentation": mask_data["segmentations"][idx], 206 "area": amg_utils.area_from_rle(mask_data["rles"][idx]), 207 "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 208 "predicted_iou": mask_data["iou_preds"][idx].item(), 209 "stability_score": mask_data["stability_score"][idx].item(), 210 "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 211 } 212 # the data from embedding based segmentation doesn't have the points 213 # so we skip if the corresponding key can't be found 214 try: 215 ann["point_coords"] = [mask_data["points"][idx].tolist()] 216 except KeyError: 217 pass 218 curr_anns.append(ann) 219 220 return curr_anns 221 222 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): 223 orig_h, orig_w = original_size 224 225 # serialize predictions and store in MaskData 226 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1)) 227 if points is not None: 228 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float) 229 230 del masks 231 232 # calculate the stability scores 233 data["stability_score"] = amg_utils.calculate_stability_score( 234 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset 235 ) 236 237 # threshold masks and calculate boxes 238 data["masks"] = data["masks"] > self._predictor.model.mask_threshold 239 data["masks"] = data["masks"].type(torch.bool) 240 data["boxes"] = batched_mask_to_box(data["masks"]) 241 242 # compress to RLE 243 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 244 # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) 245 data["rles"] = mask_to_rle_pytorch(data["masks"]) 246 del data["masks"] 247 248 return data 249 250 def get_state(self) -> Dict[str, Any]: 251 """Get the initialized state of the mask generator. 252 253 Returns: 254 State of the mask generator. 255 """ 256 if not self.is_initialized: 257 raise RuntimeError("The state has not been computed yet. Call initialize first.") 258 259 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size} 260 261 def set_state(self, state: Dict[str, Any]) -> None: 262 """Set the state of the mask generator. 263 264 Args: 265 state: The state of the mask generator, e.g. from serialized state. 266 """ 267 self._crop_list = state["crop_list"] 268 self._crop_boxes = state["crop_boxes"] 269 self._original_size = state["original_size"] 270 self._is_initialized = True 271 272 def clear_state(self): 273 """Clear the state of the mask generator. 274 """ 275 self._crop_list = None 276 self._crop_boxes = None 277 self._original_size = None 278 self._is_initialized = False 279 280 281class AutomaticMaskGenerator(AMGBase): 282 """Generates an instance segmentation without prompts, using a point grid. 
283 284 This class implements the same logic as 285 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 286 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 287 to filter these masks to enable grid search and interactively changing the post-processing. 288 289 Use this class as follows: 290 ```python 291 amg = AutomaticMaskGenerator(predictor) 292 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 293 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 294 ``` 295 296 Args: 297 predictor: The segment anything predictor. 298 points_per_side: The number of points to be sampled along one side of the image. 299 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 300 points_per_batch: The number of points run simultaneously by the model. 301 Higher numbers may be faster but use more GPU memory. 302 By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons). 303 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 304 By default, set to '0'. 305 crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'. 306 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 307 By default, set to '1'. 308 point_grids: A list over explicit grids of points used for sampling masks. 309 Normalized to [0, 1] with respect to the image coordinate system. 310 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 311 By default, set to '1.0'. 312 """ 313 def __init__( 314 self, 315 predictor: SamPredictor, 316 points_per_side: Optional[int] = 32, 317 points_per_batch: Optional[int] = None, 318 crop_n_layers: int = 0, 319 crop_overlap_ratio: float = 512 / 1500, 320 crop_n_points_downscale_factor: int = 1, 321 point_grids: Optional[List[np.ndarray]] = None, 322 stability_score_offset: float = 1.0, 323 ): 324 super().__init__() 325 326 if points_per_side is not None: 327 self.point_grids = amg_utils.build_all_layer_point_grids( 328 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 329 ) 330 elif point_grids is not None: 331 self.point_grids = point_grids 332 else: 333 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 334 335 self._predictor = predictor 336 self._points_per_side = points_per_side 337 338 # we set the points per batch to 16 for mps for performance reasons 339 # and otherwise keep them at the default of 64 340 if points_per_batch is None: 341 points_per_batch = 16 if str(predictor.device) == "mps" else 64 342 self._points_per_batch = points_per_batch 343 344 self._crop_n_layers = crop_n_layers 345 self._crop_overlap_ratio = crop_overlap_ratio 346 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 347 self._stability_score_offset = stability_score_offset 348 349 def _process_batch(self, points, im_size, crop_box, original_size): 350 # run model on this batch 351 transformed_points = self._predictor.transform.apply_coords(points, im_size) 352 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 353 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 354 masks, iou_preds, _ = self._predictor.predict_torch( 355 point_coords=in_points[:, 
None, :], 356 point_labels=in_labels[:, None], 357 multimask_output=True, 358 return_logits=True, 359 ) 360 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 361 del masks 362 return data 363 364 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 365 # Crop the image and calculate embeddings. 366 x0, y0, x1, y1 = crop_box 367 cropped_im = image[y0:y1, x0:x1, :] 368 cropped_im_size = cropped_im.shape[:2] 369 370 if not precomputed_embeddings: 371 self._predictor.set_image(cropped_im) 372 373 # Get the points for this crop. 374 points_scale = np.array(cropped_im_size)[None, ::-1] 375 points_for_image = self.point_grids[crop_layer_idx] * points_scale 376 377 # Generate masks for this crop in batches. 378 data = amg_utils.MaskData() 379 n_batches = len(points_for_image) // self._points_per_batch +\ 380 int(len(points_for_image) % self._points_per_batch != 0) 381 if pbar_init is not None: 382 pbar_init(n_batches, "Predict masks for point grid prompts") 383 384 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 385 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 386 data.cat(batch_data) 387 del batch_data 388 if pbar_update is not None: 389 pbar_update(1) 390 391 if not precomputed_embeddings: 392 self._predictor.reset_image() 393 394 return data 395 396 @torch.no_grad() 397 def initialize( 398 self, 399 image: np.ndarray, 400 image_embeddings: Optional[util.ImageEmbeddings] = None, 401 i: Optional[int] = None, 402 verbose: bool = False, 403 pbar_init: Optional[callable] = None, 404 pbar_update: Optional[callable] = None, 405 ) -> None: 406 """Initialize image embeddings and masks for an image. 407 408 Args: 409 image: The input image, volume or timeseries. 410 image_embeddings: Optional precomputed image embeddings. 411 See `util.precompute_image_embeddings` for details. 412 i: Index for the image data. Required if `image` has three spatial dimensions 413 or a time dimension and two spatial dimensions. 414 verbose: Whether to print computation progress. By default, set to 'False'. 415 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 416 Can be used together with pbar_update to handle napari progress bar in other thread. 417 To enables using this function within a threadworker. 418 pbar_update: Callback to update an external progress bar. 419 """ 420 original_size = image.shape[:2] 421 self._original_size = original_size 422 423 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 424 original_size, self._crop_n_layers, self._crop_overlap_ratio 425 ) 426 427 # We can set fixed image embeddings if we only have a single crop box (the default setting). 428 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 
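# With the default crop_n_layers=0 there is exactly one crop that covers the full image,
# so precomputed embeddings can simply be reused for it.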
429 if len(crop_boxes) == 1: 430 if image_embeddings is None: 431 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 432 util.set_precomputed(self._predictor, image_embeddings, i=i) 433 precomputed_embeddings = True 434 else: 435 precomputed_embeddings = False 436 437 # we need to cast to the image representation that is compatible with SAM 438 image = util._to_image(image) 439 440 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 441 442 crop_list = [] 443 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 444 crop_data = self._process_crop( 445 image, crop_box, layer_idx, 446 precomputed_embeddings=precomputed_embeddings, 447 pbar_init=pbar_init, pbar_update=pbar_update, 448 ) 449 crop_list.append(crop_data) 450 pbar_close() 451 452 self._is_initialized = True 453 self._crop_list = crop_list 454 self._crop_boxes = crop_boxes 455 456 @torch.no_grad() 457 def generate( 458 self, 459 pred_iou_thresh: float = 0.88, 460 stability_score_thresh: float = 0.95, 461 box_nms_thresh: float = 0.7, 462 crop_nms_thresh: float = 0.7, 463 min_mask_region_area: int = 0, 464 output_mode: str = "instance_segmentation", 465 with_background: bool = True, 466 ) -> Union[List[Dict[str, Any]], np.ndarray]: 467 """Generate instance segmentation for the currently initialized image. 468 469 Args: 470 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 471 By default, set to '0.88'. 472 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 473 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 474 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 475 By default, set to '0.7'. 476 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 477 By default, set to '0.7'. 478 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 479 output_mode: The form masks are returned in. Possible values are: 480 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 481 - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format. 482 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 483 - 'rle': Return a list of dictionaries with run-length encoded masks. 484 By default, set to 'instance_segmentation'. 485 with_background: Whether to remove the largest object, which often covers the background. 486 487 Returns: 488 The segmentation masks. 489 """ 490 if not self.is_initialized: 491 raise RuntimeError("AutomaticMaskGenerator has not been initialized. 
Call initialize first.") 492 493 data = amg_utils.MaskData() 494 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 495 crop_data = self._postprocess_batch( 496 data=deepcopy(data_), 497 crop_box=crop_box, original_size=self.original_size, 498 pred_iou_thresh=pred_iou_thresh, 499 stability_score_thresh=stability_score_thresh, 500 box_nms_thresh=box_nms_thresh 501 ) 502 data.cat(crop_data) 503 504 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 505 # Prefer masks from smaller crops 506 scores = 1 / box_area(data["crop_boxes"]) 507 scores = scores.to(data["boxes"].device) 508 keep_by_nms = batched_nms( 509 data["boxes"].float(), 510 scores, 511 torch.zeros_like(data["boxes"][:, 0]), # categories 512 iou_threshold=crop_nms_thresh, 513 ) 514 data.filter(keep_by_nms) 515 516 data.to_numpy() 517 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 518 if output_mode == "instance_segmentation": 519 shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size 520 masks = util.mask_data_to_segmentation( 521 masks, shape=shape, with_background=with_background, merge_exclusively=False 522 ) 523 return masks 524 525 526# Helper function for tiled embedding computation and checking consistent state. 527def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size, mask, i): 528 if image_embeddings is None: 529 if tile_shape is None or halo is None: 530 raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.") 531 image_embeddings = util.precompute_image_embeddings( 532 predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size, mask=mask, 533 ) 534 535 # Use tile shape and halo from the precomputed embeddings if not given. 536 # Otherwise check that they are consistent. 537 feats = image_embeddings["features"] 538 tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"]) 539 if tile_shape is None: 540 tile_shape = tile_shape_ 541 elif tile_shape != tile_shape_: 542 raise ValueError( 543 f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}." 544 ) 545 if halo is None: 546 halo = halo_ 547 elif halo != halo_: 548 raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.") 549 550 tiles_in_mask = feats.attrs.get("tiles_in_mask", None) 551 if tiles_in_mask is not None and i is not None: 552 tiles_in_mask = tiles_in_mask[str(i)] 553 554 return image_embeddings, tile_shape, halo, tiles_in_mask 555 556 557class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 558 """Generates an instance segmentation without prompts, using a point grid. 559 560 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 561 562 Args: 563 predictor: The Segment Anything predictor. 564 points_per_side: The number of points to be sampled along one side of the image. 565 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 566 points_per_batch: The number of points run simultaneously by the model. 567 Higher numbers may be faster but use more GPU memory. By default, set to '64'. 568 point_grids: A list over explicit grids of points used for sampling masks. 569 Normalized to [0, 1] with respect to the image coordinate system. 570 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 571 By default, set to '1.0'. 
572 """ 573 574 # We only expose the arguments that make sense for the tiled mask generator. 575 # Anything related to crops doesn't make sense, because we re-use that functionality 576 # for tiling, so these parameters wouldn't have any effect. 577 def __init__( 578 self, 579 predictor: SamPredictor, 580 points_per_side: Optional[int] = 32, 581 points_per_batch: int = 64, 582 point_grids: Optional[List[np.ndarray]] = None, 583 stability_score_offset: float = 1.0, 584 ) -> None: 585 super().__init__( 586 predictor=predictor, 587 points_per_side=points_per_side, 588 points_per_batch=points_per_batch, 589 point_grids=point_grids, 590 stability_score_offset=stability_score_offset, 591 ) 592 593 @torch.no_grad() 594 def initialize( 595 self, 596 image: np.ndarray, 597 image_embeddings: Optional[util.ImageEmbeddings] = None, 598 i: Optional[int] = None, 599 tile_shape: Optional[Tuple[int, int]] = None, 600 halo: Optional[Tuple[int, int]] = None, 601 verbose: bool = False, 602 pbar_init: Optional[callable] = None, 603 pbar_update: Optional[callable] = None, 604 batch_size: int = 1, 605 mask: Optional[np.typing.ArrayLike] = None, 606 ) -> None: 607 """Initialize image embeddings and masks for an image. 608 609 Args: 610 image: The input image, volume or timeseries. 611 image_embeddings: Optional precomputed image embeddings. 612 See `util.precompute_image_embeddings` for details. 613 i: Index for the image data. Required if `image` has three spatial dimensions 614 or a time dimension and two spatial dimensions. 615 tile_shape: The tile shape for embedding prediction. 616 halo: The overlap of between tiles. 617 verbose: Whether to print computation progress. By default, set to 'False'. 618 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 619 Can be used together with pbar_update to handle napari progress bar in other thread. 620 To enables using this function within a threadworker. 621 pbar_update: Callback to update an external progress bar. 622 batch_size: The batch size for image embedding prediction. By default, set to '1'. 623 mask: An optional mask to define areas that are ignored in the segmentation. 624 """ 625 original_size = image.shape[:2] 626 self._original_size = original_size 627 628 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 629 self._predictor, image, image_embeddings, tile_shape, halo, 630 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 631 ) 632 633 tiling = blocking([0, 0], original_size, tile_shape) 634 if tiles_in_mask is None: 635 n_tiles = tiling.numberOfBlocks 636 tile_ids = range(n_tiles) 637 else: 638 n_tiles = len(tiles_in_mask) 639 tile_ids = tiles_in_mask 640 641 # The crop box is always the full local tile. 642 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in tile_ids] 643 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 644 645 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 646 pbar_init(n_tiles, "Compute masks for tile") 647 648 # We need to cast to the image representation that is compatible with SAM. 
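# Note that the tiles defined above take the role of the 'crops' of the AMG base class,
# so the per-tile mask data can be post-processed with the same machinery as for crops.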
649 image = util._to_image(image) 650 651 mask_data = [] 652 for idx, tile_id in enumerate(tile_ids): 653 # set the pre-computed embeddings for this tile 654 features = image_embeddings["features"][str(tile_id)] 655 tile_embeddings = { 656 "features": features, 657 "input_size": features.attrs["input_size"], 658 "original_size": features.attrs["original_size"], 659 } 660 util.set_precomputed(self._predictor, tile_embeddings, i) 661 662 # Compute the mask data for this tile and append it 663 this_mask_data = self._process_crop( 664 image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True 665 ) 666 mask_data.append(this_mask_data) 667 pbar_update(1) 668 pbar_close() 669 670 # set the initialized data 671 self._is_initialized = True 672 self._crop_list = mask_data 673 self._crop_boxes = crop_boxes 674 675 676# 677# Instance segmentation functionality based on fine-tuned decoder 678# 679 680 681class DecoderAdapter(torch.nn.Module): 682 """Adapter to contain the UNETR decoder in a single module. 683 684 To apply the decoder on top of pre-computed embeddings for the segmentation functionality. 685 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 686 """ 687 def __init__(self, unetr: torch.nn.Module): 688 super().__init__() 689 690 self.base = unetr.base 691 self.out_conv = unetr.out_conv 692 self.deconv_out = unetr.deconv_out 693 self.decoder_head = unetr.decoder_head 694 self.final_activation = unetr.final_activation 695 self.postprocess_masks = unetr.postprocess_masks 696 697 self.decoder = unetr.decoder 698 self.deconv1 = unetr.deconv1 699 self.deconv2 = unetr.deconv2 700 self.deconv3 = unetr.deconv3 701 self.deconv4 = unetr.deconv4 702 703 def _forward_impl(self, input_): 704 z12 = input_ 705 706 z9 = self.deconv1(z12) 707 z6 = self.deconv2(z9) 708 z3 = self.deconv3(z6) 709 z0 = self.deconv4(z3) 710 711 updated_from_encoder = [z9, z6, z3] 712 713 x = self.base(z12) 714 x = self.decoder(x, encoder_inputs=updated_from_encoder) 715 x = self.deconv_out(x) 716 717 x = torch.cat([x, z0], dim=1) 718 x = self.decoder_head(x) 719 720 x = self.out_conv(x) 721 if self.final_activation is not None: 722 x = self.final_activation(x) 723 return x 724 725 def forward(self, input_, input_shape, original_shape): 726 x = self._forward_impl(input_) 727 x = self.postprocess_masks(x, input_shape, original_shape) 728 return x 729 730 731def get_unetr( 732 image_encoder: torch.nn.Module, 733 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 734 device: Optional[Union[str, torch.device]] = None, 735 out_channels: int = 3, 736 flexible_load_checkpoint: bool = False, 737) -> torch.nn.Module: 738 """Get UNETR model for automatic instance segmentation. 739 740 Args: 741 image_encoder: The image encoder of the SAM model. 742 This is used as encoder by the UNETR too. 743 decoder_state: Optional decoder state to initialize the weights of the UNETR decoder. 744 device: The device. By default, automatically chooses the best available device. 745 out_channels: The number of output channels. By default, set to '3'. 746 flexible_load_checkpoint: Whether to allow reinitialization of parameters 747 which could not be found in the provided decoder state. By default, set to 'False'. 748 749 Returns: 750 The UNETR model. 751 """ 752 device = util.get_device(device) 753 754 if decoder_state is None: 755 use_conv_transpose = False # By default, we use interpolation for upsampling. 
756 else: 757 # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions. 758 # NOTE: Explanation to the logic below - 759 # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers" 760 # submodules. This naming convention indicates that transposed convolutions are used, 761 # wrapped inside a custom block module. 762 # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling. 763 use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers")) 764 765 unetr = UNETR( 766 backbone="sam", 767 encoder=image_encoder, 768 out_channels=out_channels, 769 use_sam_stats=True, 770 final_activation="Sigmoid", 771 use_skip_connection=False, 772 resize_input=True, 773 use_conv_transpose=use_conv_transpose, 774 ) 775 776 if decoder_state is not None: 777 unetr_state_dict = unetr.state_dict() 778 for k, v in unetr_state_dict.items(): 779 if not k.startswith("encoder"): 780 if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found. 781 if k in decoder_state: # First check whether the key is available in the provided decoder state. 782 unetr_state_dict[k] = decoder_state[k] 783 else: # Otherwise, allow it to initialize it. 784 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 785 unetr_state_dict[k] = v 786 787 else: # Whether be strict on finding the parameter in the decoder state. 788 if k not in decoder_state: 789 raise RuntimeError(f"The parameters for '{k}' could not be found.") 790 unetr_state_dict[k] = decoder_state[k] 791 792 unetr.load_state_dict(unetr_state_dict) 793 794 unetr.to(device) 795 return unetr 796 797 798def get_decoder( 799 image_encoder: torch.nn.Module, 800 decoder_state: OrderedDict[str, torch.Tensor], 801 device: Optional[Union[str, torch.device]] = None, 802) -> DecoderAdapter: 803 """Get decoder to predict outputs for automatic instance segmentation 804 805 Args: 806 image_encoder: The image encoder of the SAM model. 807 decoder_state: State to initialize the weights of the UNETR decoder. 808 device: The device. By default, automatically chooses the best available device. 809 810 Returns: 811 The decoder for instance segmentation. 812 """ 813 unetr = get_unetr(image_encoder, decoder_state, device) 814 return DecoderAdapter(unetr) 815 816 817def get_predictor_and_decoder( 818 model_type: str, 819 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 820 device: Optional[Union[str, torch.device]] = None, 821 peft_kwargs: Optional[Dict] = None, 822) -> Tuple[SamPredictor, DecoderAdapter]: 823 """Load the SAM model (predictor) and instance segmentation decoder. 824 825 This requires a checkpoint that contains the state for both predictor 826 and decoder. 827 828 Args: 829 model_type: The type of the image encoder used in the SAM model. 830 checkpoint_path: Path to the checkpoint from which to load the data. 831 device: The device. By default, automatically chooses the best available device. 832 peft_kwargs: Keyword arguments for the PEFT wrapper class. 833 834 Returns: 835 The SAM predictor. 836 The decoder for instance segmentation. 
837 """ 838 device = util.get_device(device) 839 predictor, state = util.get_sam_model( 840 model_type=model_type, 841 checkpoint_path=checkpoint_path, 842 device=device, 843 return_state=True, 844 peft_kwargs=peft_kwargs, 845 ) 846 847 if "decoder_state" not in state: 848 raise ValueError( 849 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 850 ) 851 852 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 853 return predictor, decoder 854 855 856def _watershed_from_center_and_boundary_distances_parallel( 857 center_distances, 858 boundary_distances, 859 foreground_map, 860 center_distance_threshold, 861 boundary_distance_threshold, 862 foreground_threshold, 863 distance_smoothing, 864 min_size, 865 tile_shape, 866 halo, 867 n_threads, 868 verbose=False, 869): 870 center_distances = apply_filter( 871 center_distances, "gaussianSmoothing", sigma=distance_smoothing, 872 block_shape=tile_shape, n_threads=n_threads 873 ) 874 boundary_distances = apply_filter( 875 boundary_distances, "gaussianSmoothing", sigma=distance_smoothing, 876 block_shape=tile_shape, n_threads=n_threads 877 ) 878 879 fg_mask = foreground_map > foreground_threshold 880 881 marker_map = np.logical_and( 882 center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold 883 ) 884 marker_map[~fg_mask] = 0 885 886 markers = np.zeros(marker_map.shape, dtype="uint64") 887 markers = parallel_impl.label( 888 marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose, 889 ) 890 891 seg = np.zeros_like(markers, dtype="uint64") 892 seg = parallel_impl.seeded_watershed( 893 boundary_distances, seeds=markers, out=seg, block_shape=tile_shape, 894 halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask, 895 ) 896 897 out = np.zeros_like(seg, dtype="uint64") 898 out = parallel_impl.size_filter( 899 seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose 900 ) 901 902 return out 903 904 905class InstanceSegmentationWithDecoder: 906 """Generates an instance segmentation without prompts, using a decoder. 907 908 Implements the same interface as `AutomaticMaskGenerator`. 909 910 Use this class as follows: 911 ```python 912 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 913 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 914 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 915 ``` 916 917 Args: 918 predictor: The segment anything predictor. 919 decoder: The decoder to predict intermediate representations for instance segmentation. 920 """ 921 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 922 self._predictor = predictor 923 self._decoder = decoder 924 925 # The decoder outputs. 926 self._foreground = None 927 self._center_distances = None 928 self._boundary_distances = None 929 930 self._is_initialized = False 931 932 @property 933 def is_initialized(self): 934 """Whether the mask generator has already been initialized. 
935 """ 936 return self._is_initialized 937 938 @torch.no_grad() 939 def initialize( 940 self, 941 image: np.ndarray, 942 image_embeddings: Optional[util.ImageEmbeddings] = None, 943 i: Optional[int] = None, 944 verbose: bool = False, 945 pbar_init: Optional[callable] = None, 946 pbar_update: Optional[callable] = None, 947 ndim: int = 2, 948 ) -> None: 949 """Initialize image embeddings and decoder predictions for an image. 950 951 Args: 952 image: The input image, volume or timeseries. 953 image_embeddings: Optional precomputed image embeddings. 954 See `util.precompute_image_embeddings` for details. 955 i: Index for the image data. Required if `image` has three spatial dimensions 956 or a time dimension and two spatial dimensions. 957 verbose: Whether to be verbose. By default, set to 'False'. 958 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 959 Can be used together with pbar_update to handle napari progress bar in other thread. 960 To enables using this function within a threadworker. 961 pbar_update: Callback to update an external progress bar. 962 ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2. 963 """ 964 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 965 pbar_init(1, "Initialize instance segmentation with decoder") 966 967 if image_embeddings is None: 968 image_embeddings = util.precompute_image_embeddings( 969 predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose 970 ) 971 972 # Get the image embeddings from the predictor. 973 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 974 embeddings = self._predictor.features 975 input_shape = tuple(self._predictor.input_size) 976 original_shape = tuple(self._predictor.original_size) 977 978 # Run prediction with the UNETR decoder. 979 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 980 assert output.shape[0] == 3, f"{output.shape}" 981 pbar_update(1) 982 pbar_close() 983 984 # Set the state. 985 self._foreground = output[0] 986 self._center_distances = output[1] 987 self._boundary_distances = output[2] 988 self._i = i 989 self._is_initialized = True 990 991 def _to_masks(self, segmentation, output_mode): 992 if output_mode != "binary_mask": 993 raise ValueError( 994 f"Output mode {output_mode} is not supported. 
Choose one of 'instance_segmentation', 'binary_masks'" 995 ) 996 997 props = regionprops(segmentation) 998 ndim = segmentation.ndim 999 assert ndim in (2, 3) 1000 1001 shape = segmentation.shape 1002 if ndim == 2: 1003 crop_box = [0, shape[1], 0, shape[0]] 1004 else: 1005 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1006 1007 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1008 def to_bbox_2d(bbox): 1009 y0, x0 = bbox[0], bbox[1] 1010 w = bbox[3] - x0 1011 h = bbox[2] - y0 1012 return [x0, w, y0, h] 1013 1014 def to_bbox_3d(bbox): 1015 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1016 w = bbox[5] - x0 1017 h = bbox[4] - y0 1018 d = bbox[3] - y0 1019 return [x0, w, y0, h, z0, d] 1020 1021 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1022 masks = [ 1023 { 1024 "segmentation": segmentation == prop.label, 1025 "area": prop.area, 1026 "bbox": to_bbox(prop.bbox), 1027 "crop_box": crop_box, 1028 "seg_id": prop.label, 1029 } for prop in props 1030 ] 1031 return masks 1032 1033 def generate( 1034 self, 1035 center_distance_threshold: float = 0.5, 1036 boundary_distance_threshold: float = 0.5, 1037 foreground_threshold: float = 0.5, 1038 foreground_smoothing: float = 1.0, 1039 distance_smoothing: float = 1.6, 1040 min_size: int = 0, 1041 output_mode: str = "instance_segmentation", 1042 tile_shape: Optional[Tuple[int, int]] = None, 1043 halo: Optional[Tuple[int, int]] = None, 1044 n_threads: Optional[int] = None, 1045 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1046 """Generate instance segmentation for the currently initialized image. 1047 1048 Args: 1049 center_distance_threshold: Center distance predictions below this value will be 1050 used to find seeds (intersected with thresholded boundary distance predictions). 1051 By default, set to '0.5'. 1052 boundary_distance_threshold: Boundary distance predictions below this value will be 1053 used to find seeds (intersected with thresholded center distance predictions). 1054 By default, set to '0.5'. 1055 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1056 By default, set to '0.5'. 1057 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1058 checkerboard artifacts in the prediction. By default, set to '1.0'. 1059 distance_smoothing: Sigma value for smoothing the distance predictions. 1060 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1061 output_mode: The form masks are returned in. Possible values are: 1062 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1063 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1064 By default, set to 'instance_segmentation'. 1065 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1066 This parameter is independent from the tile shape for computing the embeddings. 1067 If not given then post-processing will not be parallelized. 1068 halo: Halo for parallel post-processing. See also `tile_shape`. 1069 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1070 1071 Returns: 1072 The segmentation masks. 1073 """ 1074 if not self.is_initialized: 1075 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1076 1077 if foreground_smoothing > 0: 1078 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1079 else: 1080 foreground = self._foreground 1081 1082 if tile_shape is None: 1083 segmentation = watershed_from_center_and_boundary_distances( 1084 center_distances=self._center_distances, 1085 boundary_distances=self._boundary_distances, 1086 foreground_map=foreground, 1087 center_distance_threshold=center_distance_threshold, 1088 boundary_distance_threshold=boundary_distance_threshold, 1089 foreground_threshold=foreground_threshold, 1090 distance_smoothing=distance_smoothing, 1091 min_size=min_size, 1092 ) 1093 else: 1094 if halo is None: 1095 raise ValueError("You must pass a value for halo if tile_shape is given.") 1096 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1097 center_distances=self._center_distances, 1098 boundary_distances=self._boundary_distances, 1099 foreground_map=foreground, 1100 center_distance_threshold=center_distance_threshold, 1101 boundary_distance_threshold=boundary_distance_threshold, 1102 foreground_threshold=foreground_threshold, 1103 distance_smoothing=distance_smoothing, 1104 min_size=min_size, 1105 tile_shape=tile_shape, 1106 halo=halo, 1107 n_threads=n_threads, 1108 verbose=False, 1109 ) 1110 1111 if output_mode != "instance_segmentation": 1112 segmentation = self._to_masks(segmentation, output_mode) 1113 return segmentation 1114 1115 def get_state(self) -> Dict[str, Any]: 1116 """Get the initialized state of the instance segmenter. 1117 1118 Returns: 1119 Instance segmentation state. 1120 """ 1121 if not self.is_initialized: 1122 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1123 1124 return { 1125 "foreground": self._foreground, 1126 "center_distances": self._center_distances, 1127 "boundary_distances": self._boundary_distances, 1128 } 1129 1130 def set_state(self, state: Dict[str, Any]) -> None: 1131 """Set the state of the instance segmenter. 1132 1133 Args: 1134 state: The instance segmentation state 1135 """ 1136 self._foreground = state["foreground"] 1137 self._center_distances = state["center_distances"] 1138 self._boundary_distances = state["boundary_distances"] 1139 self._is_initialized = True 1140 1141 def clear_state(self): 1142 """Clear the state of the instance segmenter. 1143 """ 1144 self._foreground = None 1145 self._center_distances = None 1146 self._boundary_distances = None 1147 self._is_initialized = False 1148 1149 1150class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1151 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1152 """ 1153 1154 # Apply the decoder in a batched fashion, and then perform the resizing independently per output. 1155 # This is necessary, because the individual tiles may have different tile shapes due to border tiles. 
1156 def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes): 1157 batched_embeddings = torch.cat(batched_embeddings) 1158 output = self._decoder._forward_impl(batched_embeddings) 1159 1160 batched_output = [] 1161 for x, input_shape, original_shape in zip(output, input_shapes, original_shapes): 1162 x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0) 1163 batched_output.append(x.cpu().numpy()) 1164 return batched_output 1165 1166 @torch.no_grad() 1167 def initialize( 1168 self, 1169 image: np.ndarray, 1170 image_embeddings: Optional[util.ImageEmbeddings] = None, 1171 i: Optional[int] = None, 1172 tile_shape: Optional[Tuple[int, int]] = None, 1173 halo: Optional[Tuple[int, int]] = None, 1174 verbose: bool = False, 1175 pbar_init: Optional[callable] = None, 1176 pbar_update: Optional[callable] = None, 1177 batch_size: int = 1, 1178 mask: Optional[np.typing.ArrayLike] = None, 1179 ) -> None: 1180 """Initialize image embeddings and decoder predictions for an image. 1181 1182 Args: 1183 image: The input image, volume or timeseries. 1184 image_embeddings: Optional precomputed image embeddings. 1185 See `util.precompute_image_embeddings` for details. 1186 i: Index for the image data. Required if `image` has three spatial dimensions 1187 or a time dimension and two spatial dimensions. 1188 tile_shape: Shape of the tiles for precomputing image embeddings. 1189 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1190 verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'. 1191 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1192 Can be used together with pbar_update to handle napari progress bar in other thread. 1193 To enables using this function within a threadworker. 1194 pbar_update: Callback to update an external progress bar. 1195 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 1196 mask: An optional mask to define areas that are ignored in the segmentation. 1197 """ 1198 original_size = image.shape[:2] 1199 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 1200 self._predictor, image, image_embeddings, tile_shape, halo, 1201 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 1202 ) 1203 tiling = blocking([0, 0], original_size, tile_shape) 1204 1205 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1206 1207 foreground = np.zeros(original_size, dtype="float32") 1208 center_distances = np.zeros(original_size, dtype="float32") 1209 boundary_distances = np.zeros(original_size, dtype="float32") 1210 1211 msg = "Initialize tiled instance segmentation with decoder" 1212 if tiles_in_mask is None: 1213 n_tiles = tiling.numberOfBlocks 1214 all_tile_ids = list(range(n_tiles)) 1215 else: 1216 n_tiles = len(tiles_in_mask) 1217 all_tile_ids = tiles_in_mask 1218 msg += " and mask" 1219 1220 n_batches = int(np.ceil(n_tiles / batch_size)) 1221 pbar_init(n_tiles, msg) 1222 tile_ids_for_batches = np.array_split(all_tile_ids, n_batches) 1223 1224 for tile_ids in tile_ids_for_batches: 1225 batched_embeddings, input_shapes, original_shapes = [], [], [] 1226 for tile_id in tile_ids: 1227 # Get the image embeddings from the predictor for this tile. 
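# set_precomputed loads the stored features for this tile into the predictor; the input_size and
# original_size read from it below are needed to resize the decoder output for this tile.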
1228 self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id) 1229 1230 batched_embeddings.append(self._predictor.features) 1231 input_shapes.append(tuple(self._predictor.input_size)) 1232 original_shapes.append(tuple(self._predictor.original_size)) 1233 1234 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1235 1236 for output_id, tile_id in enumerate(tile_ids): 1237 output = batched_output[output_id] 1238 assert output.shape[0] == 3 1239 1240 # Set the predictions in the output for this tile. 1241 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1242 local_bb = tuple( 1243 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1244 ) 1245 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1246 1247 foreground[inner_bb] = output[0][local_bb] 1248 center_distances[inner_bb] = output[1][local_bb] 1249 boundary_distances[inner_bb] = output[2][local_bb] 1250 pbar_update(1) 1251 1252 pbar_close() 1253 1254 # Set the state. 1255 self._i = i 1256 self._foreground = foreground 1257 self._center_distances = center_distances 1258 self._boundary_distances = boundary_distances 1259 self._is_initialized = True 1260 1261 1262def _get_centers(segmentation, avoid_image_border=True): 1263 # Not sure if 'find_boundaries' has to be paralellized. 1264 # Compute the distance transform to object bounaries. 1265 boundaries = find_boundaries(segmentation, mode="outer") == 0 1266 # Avoid centroids on the border. 1267 if avoid_image_border: 1268 boundaries[0, :] = False 1269 boundaries[:, 0] = False 1270 boundaries[-1, :] = False 1271 boundaries[:, -1] = False 1272 block_shape = (512, 512) 1273 halo = (16, 16) 1274 distances = parallel_impl.distance_transform(boundaries, halo=halo, block_shape=block_shape) 1275 1276 # Get the maximum coordinate of the distance transform for each id. 1277 props = regionprops(segmentation) 1278 centers = [] 1279 for prop in props: 1280 seg_id = prop.label 1281 # Get the bounding box and mask for this segmentation id. 1282 bounding_box = np.s_[prop.bbox[0]:prop.bbox[2], prop.bbox[1]:prop.bbox[3]] 1283 mask = segmentation[bounding_box] == seg_id 1284 # Restrict the distances to the bounding box. 1285 dist = distances[bounding_box].copy() 1286 # Set the distances outside of the mask to 0. 1287 dist[~mask] = 0 1288 # Find the center coordinate (= distance maximum). 1289 center = np.argmax(dist) 1290 center = np.unravel_index(center, dist.shape) 1291 # Bring the center coordinate back to the full coordinate system. 
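# (The distance maximum was located within the object's bounding box, so the box offset is added back here.)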
1292 center = tuple(ce + bb.start for ce, bb in zip(center, bounding_box)) 1293 centers.append(center) 1294 1295 return np.array(centers) 1296 1297 1298def _derive_point_prompts( 1299 foreground: np.ndarray, 1300 center_distances: np.ndarray, 1301 boundary_distances: np.ndarray, 1302 foreground_threshold: float = 0.5, 1303 center_distance_threshold: float = 0.5, 1304 boundary_distance_threshold: float = 0.5, 1305): 1306 bg_mask = foreground < foreground_threshold 1307 hmap_cc = np.logical_and( 1308 center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold, 1309 ) 1310 hmap_cc[bg_mask] = 0 1311 cc = np.zeros_like(hmap_cc, dtype="uint32") 1312 cc = parallel_impl.label(hmap_cc, out=cc, block_shape=(512, 512)) 1313 prompts = _get_centers(cc) 1314 if len(prompts) == 0: 1315 return None 1316 1317 points = prompts[:, None, ::-1] 1318 labels = np.ones((len(prompts), 1)) 1319 return {"points": points, "point_labels": labels} 1320 1321 1322def _derive_box_prompts(predictions, box_extension): 1323 shape = predictions[0]["segmentation"].shape 1324 bboxes = [pred["bbox"] for pred in predictions] 1325 prompts = [[ 1326 max(x - w * box_extension, 0), 1327 max(y - h * box_extension, 0), 1328 min(x + (1 + box_extension) * w, shape[0]), 1329 min(y + (1 + box_extension) * h, shape[1]), 1330 ] for (x, y, w, h) in bboxes] 1331 return {"boxes": np.array(prompts)} 1332 1333 1334class AutomaticPromptGenerator(InstanceSegmentationWithDecoder): 1335 """Generates an instance segmentation automatically, using automatically generated prompts from a decoder. 1336 1337 This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator`. 1338 1339 Args: 1340 predictor: The segment anything predictor. 1341 decoder: The decoder used to derive prompts for automatic instance segmentation. 1342 """ 1343 def generate( 1344 self, 1345 min_size: int = 25, 1346 center_distance_threshold: float = 0.5, 1347 boundary_distance_threshold: float = 0.5, 1348 foreground_threshold: float = 0.5, 1349 multimasking: bool = False, 1350 batch_size: int = 32, 1351 nms_threshold: float = 0.9, 1352 intersection_over_min: bool = False, 1353 output_mode: str = "instance_segmentation", 1354 mask_threshold: Optional[Union[float, str]] = None, 1355 refine_with_box_prompts: bool = False, 1356 prompt_function: Optional[callable] = None, 1357 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1358 """Generate the instance segmentation for the currently initialized image. 1359 1360 The instance segmentation is generated by deriving prompts from the foreground and 1361 distance predictions of the segmentation decoder by thresholding these predictions, 1362 intersecting them, computing connected components, and then using the components' 1363 centers as point prompts. The masks are then filtered via NMS and merged into a segmentation. 1364 1365 Args: 1366 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1367 center_distance_threshold: The threshold for the center distance predictions. 1368 boundary_distance_threshold: The threshold for the boundary distance predictions. 1369 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1370 batch_size: The batch size for parallelizing the prediction based on prompts. 1371 nms_threshold: The threshold for non-maximum suppression (NMS). 1372 intersection_over_min: Whether to use the minimum area of the two objects or the 1373 intersection over union area (default) in NMS.
1374 output_mode: The form masks are returned in. Possible values are: 1375 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1376 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1377 By default, set to 'instance_segmentation'. 1378 mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`. 1379 refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts 1380 derived from the segmentations after point prompts. 1381 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1382 If given, the default prompt derivation procedure is not used. Must have the following signature: 1383 ``` 1384 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1385 ``` 1386 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1387 predictions from the segmentation decoder. It must return a dictionary containing 1388 either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`. 1389 1390 Returns: 1391 The instance segmentation masks. 1392 """ 1393 if not self.is_initialized: 1394 raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.") 1395 foreground, center_distances, boundary_distances =\ 1396 self._foreground, self._center_distances, self._boundary_distances 1397 1398 # 1.) Derive prompts from the decoder predictions. 1399 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1400 prompts = prompt_function( 1401 foreground=foreground, 1402 center_distances=center_distances, 1403 boundary_distances=boundary_distances, 1404 foreground_threshold=foreground_threshold, 1405 center_distance_threshold=center_distance_threshold, 1406 boundary_distance_threshold=boundary_distance_threshold, 1407 ) 1408 1409 # 2.) Apply the predictor to the prompts. 1410 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1411 return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else [] 1412 else: 1413 predictions = batched_inference( 1414 self._predictor, 1415 image=None, 1416 batch_size=batch_size, 1417 return_instance_segmentation=False, 1418 multimasking=multimasking, 1419 mask_threshold=mask_threshold, 1420 i=getattr(self, "_i", None), 1421 **prompts, 1422 ) 1423 1424 # 3.) Refine the segmentation with box prompts. 1425 if refine_with_box_prompts: 1426 box_extension = 0.01 # expose as hyperparam? 1427 prompts = _derive_box_prompts(predictions, box_extension) 1428 predictions = batched_inference( 1429 self._predictor, 1430 image=None, 1431 batch_size=batch_size, 1432 return_instance_segmentation=False, 1433 multimasking=multimasking, 1434 mask_threshold=mask_threshold, 1435 i=getattr(self, "_i", None), 1436 **prompts, 1437 ) 1438 1439 # 4.) Apply non-max suppression to the masks. 1440 segmentation = util.apply_nms( 1441 predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1442 ) 1443 if output_mode != "instance_segmentation": 1444 segmentation = self._to_masks(segmentation, output_mode) 1445 return segmentation 1446 1447 1448class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder): 1449 """Same as `AutomaticPromptGenerator` but for tiled image embeddings.
1450 """ 1451 def generate( 1452 self, 1453 min_size: int = 25, 1454 center_distance_threshold: float = 0.5, 1455 boundary_distance_threshold: float = 0.5, 1456 foreground_threshold: float = 0.5, 1457 multimasking: bool = False, 1458 batch_size: int = 32, 1459 nms_threshold: float = 0.9, 1460 intersection_over_min: bool = False, 1461 output_mode: str = "instance_segmentation", 1462 mask_threshold: Optional[Union[float, str]] = None, 1463 refine_with_box_prompts: bool = False, 1464 prompt_function: Optional[callable] = None, 1465 optimize_memory: bool = False, 1466 ) -> List[Dict[str, Any]]: 1467 """Generate tiling-based instance segmentation for the currently initialized image. 1468 1469 Args: 1470 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1471 center_distance_threshold: The threshold for the center distance predictions. 1472 boundary_distance_threshold: The threshold for the boundary distance predictions. 1473 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1474 batch_size: The batch size for parallelizing the prediction based on prompts. 1475 nms_threshold: The threshold for non-maximum suppression (NMS). 1476 intersection_over_min: Whether to use the minimum area of the two objects or the 1477 intersection over union area (default) in NMS. 1478 output_mode: The form masks are returned in. Possible values are: 1479 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1480 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1481 By default, set to 'instance_segmentation'. 1482 mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.` 1483 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1484 derived from the segmentations after point prompts. Currently not supported for tiled segmentation. 1485 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1486 If given, the default prompt derivation procedure is not used. Must have the following signature: 1487 ``` 1488 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1489 ``` 1490 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1491 predictions from the segmentation decoder. It must returns a dictionary containing 1492 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1493 optimize_memory: Whether to optimize the memory consumption by merging the per-slice 1494 segmentation results immediatly with NMS, rather than running a NMS for all results. 1495 This may lead to a slightly different segmentation result and is only compatible with 1496 `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`. 1497 1498 Returns: 1499 The instance segmentation masks. 1500 """ 1501 if not self.is_initialized: 1502 raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.") 1503 if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts): 1504 raise ValueError("Invalid settings") 1505 foreground, center_distances, boundary_distances =\ 1506 self._foreground, self._center_distances, self._boundary_distances 1507 1508 # 1.) Derive promtps from the decoder predictions. 
1509 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1510 prompts = prompt_function( 1511 foreground, 1512 center_distances, 1513 boundary_distances, 1514 foreground_threshold=foreground_threshold, 1515 center_distance_threshold=center_distance_threshold, 1516 boundary_distance_threshold=boundary_distance_threshold, 1517 ) 1518 1519 # 2.) Apply the predictor to the prompts. 1520 shape = foreground.shape 1521 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1522 return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else [] 1523 else: 1524 if optimize_memory: 1525 prompts.update(dict( 1526 min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1527 )) 1528 predictions = batched_tiled_inference( 1529 self._predictor, 1530 image=None, 1531 batch_size=batch_size, 1532 image_embeddings=self._image_embeddings, 1533 return_instance_segmentation=False, 1534 multimasking=multimasking, 1535 optimize_memory=optimize_memory, 1536 i=getattr(self, "_i", None), 1537 **prompts 1538 ) 1539 # Optimize memory directly returns an instance segmentation and does not 1540 # allow for any further refinements. 1541 if optimize_memory: 1542 return predictions 1543 1544 # 3.) Refine the segmentation with box prompts. 1545 if refine_with_box_prompts: 1546 # TODO 1547 raise NotImplementedError 1548 1549 # 4.) Apply non-max suppression to the masks. 1550 segmentation = util.apply_nms( 1551 predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold, 1552 intersection_over_min=intersection_over_min, 1553 ) 1554 if output_mode != "instance_segmentation": 1555 segmentation = self._to_masks(segmentation, output_mode) 1556 return segmentation 1557 1558 # Set state and get state are not implemented yet, as this generator relies on having the image embeddings 1559 # in the state. However, they should not be serialized here and we have to address this a bit differently. 1560 def get_state(self): 1561 """@private 1562 """ 1563 raise NotImplementedError 1564 1565 def set_state(self, state): 1566 """@private 1567 """ 1568 raise NotImplementedError 1569 1570 1571def get_instance_segmentation_generator( 1572 predictor: SamPredictor, 1573 is_tiled: bool, 1574 decoder: Optional[torch.nn.Module] = None, 1575 segmentation_mode: Optional[str] = None, 1576 **kwargs, 1577) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1578 f"""Get the automatic mask generator. 1579 1580 Args: 1581 predictor: The segment anything predictor. 1582 is_tiled: Whether tiled embeddings are used. 1583 decoder: Decoder to predict instacne segmmentation. 1584 segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'. 1585 By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used if a decoder is passed, 1586 otherwise 'amg' is used. 1587 kwargs: The keyword arguments of the segmentation genetator class. 1588 1589 Returns: 1590 The segmentation generator instance. 1591 """ 1592 # Choose the segmentation decoder default depending on whether we have a decoder. 
1593 if segmentation_mode is None: 1594 segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER 1595 1596 if segmentation_mode.lower() == "amg": 1597 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1598 segmenter = segmenter_class(predictor, **kwargs) 1599 elif segmentation_mode.lower() == "ais": 1600 assert decoder is not None 1601 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1602 segmenter = segmenter_class(predictor, decoder, **kwargs) 1603 elif segmentation_mode.lower() == "apg": 1604 assert decoder is not None 1605 segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator 1606 segmenter = segmenter_class(predictor, decoder, **kwargs) 1607 else: 1608 raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.") 1609 1610 return segmenter
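The factory above selects the generator class from `segmentation_mode` ('amg', 'ais' or 'apg') and from whether a decoder and tiled embeddings are used. Below is a minimal sketch of the automatic-prompt-generation ('apg') workflow; the model name 'vit_b_lm' and the random image are only placeholders and not prescribed by this module.

```python
import numpy as np

from micro_sam.instance_segmentation import (
    get_instance_segmentation_generator,
    get_predictor_and_decoder,
)

# Load the SAM predictor together with the instance segmentation decoder.
# 'vit_b_lm' is only an example; any checkpoint with a decoder state works.
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")

# The factory returns an AutomaticPromptGenerator for segmentation_mode="apg".
segmenter = get_instance_segmentation_generator(
    predictor, is_tiled=False, decoder=decoder, segmentation_mode="apg"
)

image = np.random.rand(512, 512).astype("float32")  # stand-in for real image data
segmenter.initialize(image)  # compute image embeddings and decoder predictions
segmentation = segmenter.generate(output_mode="instance_segmentation")
```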
59class AMGBase(ABC): 60 """Base class for the automatic mask generators. 61 """ 62 def __init__(self): 63 # the state that has to be computed by the 'initialize' method of the child classes 64 self._is_initialized = False 65 self._crop_list = None 66 self._crop_boxes = None 67 self._original_size = None 68 69 @property 70 def is_initialized(self): 71 """Whether the mask generator has already been initialized. 72 """ 73 return self._is_initialized 74 75 @property 76 def crop_list(self): 77 """The list of mask data after initialization. 78 """ 79 return self._crop_list 80 81 @property 82 def crop_boxes(self): 83 """The list of crop boxes. 84 """ 85 return self._crop_boxes 86 87 @property 88 def original_size(self): 89 """The original image size. 90 """ 91 return self._original_size 92 93 def _postprocess_batch( 94 self, 95 data, 96 crop_box, 97 original_size, 98 pred_iou_thresh, 99 stability_score_thresh, 100 box_nms_thresh, 101 ): 102 orig_h, orig_w = original_size 103 104 # filter by predicted IoU 105 if pred_iou_thresh > 0.0: 106 keep_mask = data["iou_preds"] > pred_iou_thresh 107 data.filter(keep_mask) 108 109 # filter by stability score 110 if stability_score_thresh > 0.0: 111 keep_mask = data["stability_score"] >= stability_score_thresh 112 data.filter(keep_mask) 113 114 # filter boxes that touch crop boundaries 115 keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 116 if not torch.all(keep_mask): 117 data.filter(keep_mask) 118 119 # remove duplicates within this crop. 120 keep_by_nms = batched_nms( 121 data["boxes"].float(), 122 data["iou_preds"], 123 torch.zeros_like(data["boxes"][:, 0]), # categories 124 iou_threshold=box_nms_thresh, 125 ) 126 data.filter(keep_by_nms) 127 128 # return to the original image frame 129 data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box) 130 data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 131 # the data from embedding based segmentation doesn't have the points 132 # so we skip if the corresponding key can't be found 133 try: 134 data["points"] = amg_utils.uncrop_points(data["points"], crop_box) 135 except KeyError: 136 pass 137 138 return data 139 140 def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): 141 142 if len(mask_data["rles"]) == 0: 143 return mask_data 144 145 # filter small disconnected regions and holes 146 new_masks = [] 147 scores = [] 148 for rle in mask_data["rles"]: 149 mask = amg_utils.rle_to_mask(rle) 150 151 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes") 152 unchanged = not changed 153 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands") 154 unchanged = unchanged and not changed 155 156 new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0)) 157 # give score=0 to changed masks and score=1 to unchanged masks 158 # so NMS will prefer ones that didn't need postprocessing 159 scores.append(float(unchanged)) 160 161 # recalculate boxes and remove any new duplicates 162 masks = torch.cat(new_masks, dim=0) 163 boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. 
164 keep_by_nms = batched_nms( 165 boxes.float(), 166 torch.as_tensor(scores, dtype=torch.float), 167 torch.zeros_like(boxes[:, 0]), # categories 168 iou_threshold=nms_thresh, 169 ) 170 171 # only recalculate RLEs for masks that have changed 172 for i_mask in keep_by_nms: 173 if scores[i_mask] == 0.0: 174 mask_torch = masks[i_mask].unsqueeze(0) 175 # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] 176 mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 177 mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 178 mask_data.filter(keep_by_nms) 179 180 return mask_data 181 182 def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode): 183 # filter small disconnected regions and holes in masks 184 if min_mask_region_area > 0: 185 mask_data = self._postprocess_small_regions( 186 mask_data, 187 min_mask_region_area, 188 max(box_nms_thresh, crop_nms_thresh), 189 ) 190 191 # encode masks 192 if output_mode == "coco_rle": 193 mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]] 194 # We require the binary mask output for creating the instance-seg output. 195 elif output_mode in ("binary_mask", "instance_segmentation"): 196 mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]] 197 elif output_mode == "rle": 198 mask_data["segmentations"] = mask_data["rles"] 199 else: 200 raise ValueError(f"Invalid output mode {output_mode}.") 201 202 # write mask records 203 curr_anns = [] 204 for idx in range(len(mask_data["segmentations"])): 205 ann = { 206 "segmentation": mask_data["segmentations"][idx], 207 "area": amg_utils.area_from_rle(mask_data["rles"][idx]), 208 "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 209 "predicted_iou": mask_data["iou_preds"][idx].item(), 210 "stability_score": mask_data["stability_score"][idx].item(), 211 "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 212 } 213 # the data from embedding based segmentation doesn't have the points 214 # so we skip if the corresponding key can't be found 215 try: 216 ann["point_coords"] = [mask_data["points"][idx].tolist()] 217 except KeyError: 218 pass 219 curr_anns.append(ann) 220 221 return curr_anns 222 223 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): 224 orig_h, orig_w = original_size 225 226 # serialize predictions and store in MaskData 227 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1)) 228 if points is not None: 229 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float) 230 231 del masks 232 233 # calculate the stability scores 234 data["stability_score"] = amg_utils.calculate_stability_score( 235 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset 236 ) 237 238 # threshold masks and calculate boxes 239 data["masks"] = data["masks"] > self._predictor.model.mask_threshold 240 data["masks"] = data["masks"].type(torch.bool) 241 data["boxes"] = batched_mask_to_box(data["masks"]) 242 243 # compress to RLE 244 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 245 # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) 246 data["rles"] = mask_to_rle_pytorch(data["masks"]) 247 del data["masks"] 248 249 return data 250 251 def get_state(self) -> Dict[str, Any]: 252 """Get the initialized state of the mask generator. 
253 254 Returns: 255 State of the mask generator. 256 """ 257 if not self.is_initialized: 258 raise RuntimeError("The state has not been computed yet. Call initialize first.") 259 260 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size} 261 262 def set_state(self, state: Dict[str, Any]) -> None: 263 """Set the state of the mask generator. 264 265 Args: 266 state: The state of the mask generator, e.g. from serialized state. 267 """ 268 self._crop_list = state["crop_list"] 269 self._crop_boxes = state["crop_boxes"] 270 self._original_size = state["original_size"] 271 self._is_initialized = True 272 273 def clear_state(self): 274 """Clear the state of the mask generator. 275 """ 276 self._crop_list = None 277 self._crop_boxes = None 278 self._original_size = None 279 self._is_initialized = False
Base class for the automatic mask generators.
69 @property 70 def is_initialized(self): 71 """Whether the mask generator has already been initialized. 72 """ 73 return self._is_initialized
Whether the mask generator has already been initialized.
75 @property 76 def crop_list(self): 77 """The list of mask data after initialization. 78 """ 79 return self._crop_list
The list of mask data after initialization.
81 @property 82 def crop_boxes(self): 83 """The list of crop boxes. 84 """ 85 return self._crop_boxes
The list of crop boxes.
87 @property 88 def original_size(self): 89 """The original image size. 90 """ 91 return self._original_size
The original image size.
251 def get_state(self) -> Dict[str, Any]: 252 """Get the initialized state of the mask generator. 253 254 Returns: 255 State of the mask generator. 256 """ 257 if not self.is_initialized: 258 raise RuntimeError("The state has not been computed yet. Call initialize first.") 259 260 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
262 def set_state(self, state: Dict[str, Any]) -> None: 263 """Set the state of the mask generator. 264 265 Args: 266 state: The state of the mask generator, e.g. from serialized state. 267 """ 268 self._crop_list = state["crop_list"] 269 self._crop_boxes = state["crop_boxes"] 270 self._original_size = state["original_size"] 271 self._is_initialized = True
Set the state of the mask generator.
Arguments:
- state: The state of the mask generator, e.g. from serialized state.
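A minimal sketch of how `get_state` and `set_state` can be used to cache the expensive initialization, assuming an already initialized mask generator `amg` (e.g. an `AutomaticMaskGenerator`, see below). Pickling is only one possible way to persist the state and depends on the contents of the stored mask data, so treat this as an illustration of the API rather than a guaranteed serialization recipe.

```python
import pickle

# The state holds the expensive intermediate results computed by 'initialize'.
state = amg.get_state()  # {"crop_list": ..., "crop_boxes": ..., "original_size": ...}
with open("amg_state.pkl", "wb") as f:
    pickle.dump(state, f)

# Later (or in another process): restore the state and re-run the cheap
# post-processing with different parameters, without re-initializing.
with open("amg_state.pkl", "rb") as f:
    amg.set_state(pickle.load(f))
masks = amg.generate(pred_iou_thresh=0.8)
```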
282class AutomaticMaskGenerator(AMGBase): 283 """Generates an instance segmentation without prompts, using a point grid. 284 285 This class implements the same logic as 286 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 287 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 288 to filter these masks to enable grid search and interactively changing the post-processing. 289 290 Use this class as follows: 291 ```python 292 amg = AutomaticMaskGenerator(predictor) 293 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 294 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 295 ``` 296 297 Args: 298 predictor: The segment anything predictor. 299 points_per_side: The number of points to be sampled along one side of the image. 300 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 301 points_per_batch: The number of points run simultaneously by the model. 302 Higher numbers may be faster but use more GPU memory. 303 By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons). 304 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 305 By default, set to '0'. 306 crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'. 307 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 308 By default, set to '1'. 309 point_grids: A list over explicit grids of points used for sampling masks. 310 Normalized to [0, 1] with respect to the image coordinate system. 311 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 312 By default, set to '1.0'. 
313 """ 314 def __init__( 315 self, 316 predictor: SamPredictor, 317 points_per_side: Optional[int] = 32, 318 points_per_batch: Optional[int] = None, 319 crop_n_layers: int = 0, 320 crop_overlap_ratio: float = 512 / 1500, 321 crop_n_points_downscale_factor: int = 1, 322 point_grids: Optional[List[np.ndarray]] = None, 323 stability_score_offset: float = 1.0, 324 ): 325 super().__init__() 326 327 if points_per_side is not None: 328 self.point_grids = amg_utils.build_all_layer_point_grids( 329 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 330 ) 331 elif point_grids is not None: 332 self.point_grids = point_grids 333 else: 334 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 335 336 self._predictor = predictor 337 self._points_per_side = points_per_side 338 339 # we set the points per batch to 16 for mps for performance reasons 340 # and otherwise keep them at the default of 64 341 if points_per_batch is None: 342 points_per_batch = 16 if str(predictor.device) == "mps" else 64 343 self._points_per_batch = points_per_batch 344 345 self._crop_n_layers = crop_n_layers 346 self._crop_overlap_ratio = crop_overlap_ratio 347 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 348 self._stability_score_offset = stability_score_offset 349 350 def _process_batch(self, points, im_size, crop_box, original_size): 351 # run model on this batch 352 transformed_points = self._predictor.transform.apply_coords(points, im_size) 353 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 354 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 355 masks, iou_preds, _ = self._predictor.predict_torch( 356 point_coords=in_points[:, None, :], 357 point_labels=in_labels[:, None], 358 multimask_output=True, 359 return_logits=True, 360 ) 361 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 362 del masks 363 return data 364 365 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 366 # Crop the image and calculate embeddings. 367 x0, y0, x1, y1 = crop_box 368 cropped_im = image[y0:y1, x0:x1, :] 369 cropped_im_size = cropped_im.shape[:2] 370 371 if not precomputed_embeddings: 372 self._predictor.set_image(cropped_im) 373 374 # Get the points for this crop. 375 points_scale = np.array(cropped_im_size)[None, ::-1] 376 points_for_image = self.point_grids[crop_layer_idx] * points_scale 377 378 # Generate masks for this crop in batches. 
379 data = amg_utils.MaskData() 380 n_batches = len(points_for_image) // self._points_per_batch +\ 381 int(len(points_for_image) % self._points_per_batch != 0) 382 if pbar_init is not None: 383 pbar_init(n_batches, "Predict masks for point grid prompts") 384 385 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 386 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 387 data.cat(batch_data) 388 del batch_data 389 if pbar_update is not None: 390 pbar_update(1) 391 392 if not precomputed_embeddings: 393 self._predictor.reset_image() 394 395 return data 396 397 @torch.no_grad() 398 def initialize( 399 self, 400 image: np.ndarray, 401 image_embeddings: Optional[util.ImageEmbeddings] = None, 402 i: Optional[int] = None, 403 verbose: bool = False, 404 pbar_init: Optional[callable] = None, 405 pbar_update: Optional[callable] = None, 406 ) -> None: 407 """Initialize image embeddings and masks for an image. 408 409 Args: 410 image: The input image, volume or timeseries. 411 image_embeddings: Optional precomputed image embeddings. 412 See `util.precompute_image_embeddings` for details. 413 i: Index for the image data. Required if `image` has three spatial dimensions 414 or a time dimension and two spatial dimensions. 415 verbose: Whether to print computation progress. By default, set to 'False'. 416 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 417 Can be used together with pbar_update to handle napari progress bar in other thread. 418 To enables using this function within a threadworker. 419 pbar_update: Callback to update an external progress bar. 420 """ 421 original_size = image.shape[:2] 422 self._original_size = original_size 423 424 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 425 original_size, self._crop_n_layers, self._crop_overlap_ratio 426 ) 427 428 # We can set fixed image embeddings if we only have a single crop box (the default setting). 429 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 430 if len(crop_boxes) == 1: 431 if image_embeddings is None: 432 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 433 util.set_precomputed(self._predictor, image_embeddings, i=i) 434 precomputed_embeddings = True 435 else: 436 precomputed_embeddings = False 437 438 # we need to cast to the image representation that is compatible with SAM 439 image = util._to_image(image) 440 441 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 442 443 crop_list = [] 444 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 445 crop_data = self._process_crop( 446 image, crop_box, layer_idx, 447 precomputed_embeddings=precomputed_embeddings, 448 pbar_init=pbar_init, pbar_update=pbar_update, 449 ) 450 crop_list.append(crop_data) 451 pbar_close() 452 453 self._is_initialized = True 454 self._crop_list = crop_list 455 self._crop_boxes = crop_boxes 456 457 @torch.no_grad() 458 def generate( 459 self, 460 pred_iou_thresh: float = 0.88, 461 stability_score_thresh: float = 0.95, 462 box_nms_thresh: float = 0.7, 463 crop_nms_thresh: float = 0.7, 464 min_mask_region_area: int = 0, 465 output_mode: str = "instance_segmentation", 466 with_background: bool = True, 467 ) -> Union[List[Dict[str, Any]], np.ndarray]: 468 """Generate instance segmentation for the currently initialized image. 
469 470 Args: 471 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 472 By default, set to '0.88'. 473 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 474 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 475 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 476 By default, set to '0.7'. 477 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 478 By default, set to '0.7'. 479 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 480 output_mode: The form masks are returned in. Possible values are: 481 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 482 - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format. 483 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 484 - 'rle': Return a list of dictionaries with run-length encoded masks. 485 By default, set to 'instance_segmentation'. 486 with_background: Whether to remove the largest object, which often covers the background. 487 488 Returns: 489 The segmentation masks. 490 """ 491 if not self.is_initialized: 492 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 493 494 data = amg_utils.MaskData() 495 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 496 crop_data = self._postprocess_batch( 497 data=deepcopy(data_), 498 crop_box=crop_box, original_size=self.original_size, 499 pred_iou_thresh=pred_iou_thresh, 500 stability_score_thresh=stability_score_thresh, 501 box_nms_thresh=box_nms_thresh 502 ) 503 data.cat(crop_data) 504 505 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 506 # Prefer masks from smaller crops 507 scores = 1 / box_area(data["crop_boxes"]) 508 scores = scores.to(data["boxes"].device) 509 keep_by_nms = batched_nms( 510 data["boxes"].float(), 511 scores, 512 torch.zeros_like(data["boxes"][:, 0]), # categories 513 iou_threshold=crop_nms_thresh, 514 ) 515 data.filter(keep_by_nms) 516 517 data.to_numpy() 518 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 519 if output_mode == "instance_segmentation": 520 shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size 521 masks = util.mask_data_to_segmentation( 522 masks, shape=shape, with_background=with_background, merge_exclusively=False 523 ) 524 return masks
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive mask generation from the cheap post-processing that filters these masks, to enable grid search and interactive tuning of the post-processing.
Use this class as follows:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image) # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image. By default, set to '0'.
- crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. By default, set to '1'.
- point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
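Because `initialize` performs all expensive computation, the post-processing parameters of `generate` can be swept cheaply afterwards. A short sketch, assuming a SAM `predictor` (e.g. from `micro_sam.util.get_sam_model`) and a 2D `image` are already available:

```python
amg = AutomaticMaskGenerator(predictor, points_per_side=32)
amg.initialize(image, verbose=True)  # expensive: embeddings and mask prediction

# 'generate' only re-runs the cheap filtering, so a small grid search is fast.
results = {}
for iou_thresh in (0.8, 0.88, 0.95):
    results[iou_thresh] = amg.generate(pred_iou_thresh=iou_thresh)
```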
314 def __init__( 315 self, 316 predictor: SamPredictor, 317 points_per_side: Optional[int] = 32, 318 points_per_batch: Optional[int] = None, 319 crop_n_layers: int = 0, 320 crop_overlap_ratio: float = 512 / 1500, 321 crop_n_points_downscale_factor: int = 1, 322 point_grids: Optional[List[np.ndarray]] = None, 323 stability_score_offset: float = 1.0, 324 ): 325 super().__init__() 326 327 if points_per_side is not None: 328 self.point_grids = amg_utils.build_all_layer_point_grids( 329 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 330 ) 331 elif point_grids is not None: 332 self.point_grids = point_grids 333 else: 334 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 335 336 self._predictor = predictor 337 self._points_per_side = points_per_side 338 339 # we set the points per batch to 16 for mps for performance reasons 340 # and otherwise keep them at the default of 64 341 if points_per_batch is None: 342 points_per_batch = 16 if str(predictor.device) == "mps" else 64 343 self._points_per_batch = points_per_batch 344 345 self._crop_n_layers = crop_n_layers 346 self._crop_overlap_ratio = crop_overlap_ratio 347 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 348 self._stability_score_offset = stability_score_offset
397 @torch.no_grad() 398 def initialize( 399 self, 400 image: np.ndarray, 401 image_embeddings: Optional[util.ImageEmbeddings] = None, 402 i: Optional[int] = None, 403 verbose: bool = False, 404 pbar_init: Optional[callable] = None, 405 pbar_update: Optional[callable] = None, 406 ) -> None: 407 """Initialize image embeddings and masks for an image. 408 409 Args: 410 image: The input image, volume or timeseries. 411 image_embeddings: Optional precomputed image embeddings. 412 See `util.precompute_image_embeddings` for details. 413 i: Index for the image data. Required if `image` has three spatial dimensions 414 or a time dimension and two spatial dimensions. 415 verbose: Whether to print computation progress. By default, set to 'False'. 416 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 417 Can be used together with pbar_update to handle napari progress bar in other thread. 418 To enables using this function within a threadworker. 419 pbar_update: Callback to update an external progress bar. 420 """ 421 original_size = image.shape[:2] 422 self._original_size = original_size 423 424 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 425 original_size, self._crop_n_layers, self._crop_overlap_ratio 426 ) 427 428 # We can set fixed image embeddings if we only have a single crop box (the default setting). 429 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 430 if len(crop_boxes) == 1: 431 if image_embeddings is None: 432 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 433 util.set_precomputed(self._predictor, image_embeddings, i=i) 434 precomputed_embeddings = True 435 else: 436 precomputed_embeddings = False 437 438 # we need to cast to the image representation that is compatible with SAM 439 image = util._to_image(image) 440 441 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 442 443 crop_list = [] 444 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 445 crop_data = self._process_crop( 446 image, crop_box, layer_idx, 447 precomputed_embeddings=precomputed_embeddings, 448 pbar_init=pbar_init, pbar_update=pbar_update, 449 ) 450 crop_list.append(crop_data) 451 pbar_close() 452 453 self._is_initialized = True 454 self._crop_list = crop_list 455 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
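A sketch of initializing with precomputed embeddings, and of the `i` argument for volumetric or timeseries data; the variable names (`image`, `volume`, `volume_embeddings`) are assumptions, and the embeddings come from `util.precompute_image_embeddings` as referenced above.

```python
from micro_sam import util

# Precompute the embeddings once and pass them to 'initialize' to avoid recomputation.
image_embeddings = util.precompute_image_embeddings(predictor, image)
amg.initialize(image, image_embeddings=image_embeddings)

# For a volume or timeseries, 'i' selects the slice / timepoint that is segmented.
# amg.initialize(volume, image_embeddings=volume_embeddings, i=0)
```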
457 @torch.no_grad() 458 def generate( 459 self, 460 pred_iou_thresh: float = 0.88, 461 stability_score_thresh: float = 0.95, 462 box_nms_thresh: float = 0.7, 463 crop_nms_thresh: float = 0.7, 464 min_mask_region_area: int = 0, 465 output_mode: str = "instance_segmentation", 466 with_background: bool = True, 467 ) -> Union[List[Dict[str, Any]], np.ndarray]: 468 """Generate instance segmentation for the currently initialized image. 469 470 Args: 471 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 472 By default, set to '0.88'. 473 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 474 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 475 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 476 By default, set to '0.7'. 477 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 478 By default, set to '0.7'. 479 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 480 output_mode: The form masks are returned in. Possible values are: 481 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 482 - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format. 483 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 484 - 'rle': Return a list of dictionaries with run-length encoded masks. 485 By default, set to 'instance_segmentation'. 486 with_background: Whether to remove the largest object, which often covers the background. 487 488 Returns: 489 The segmentation masks. 490 """ 491 if not self.is_initialized: 492 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 493 494 data = amg_utils.MaskData() 495 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 496 crop_data = self._postprocess_batch( 497 data=deepcopy(data_), 498 crop_box=crop_box, original_size=self.original_size, 499 pred_iou_thresh=pred_iou_thresh, 500 stability_score_thresh=stability_score_thresh, 501 box_nms_thresh=box_nms_thresh 502 ) 503 data.cat(crop_data) 504 505 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 506 # Prefer masks from smaller crops 507 scores = 1 / box_area(data["crop_boxes"]) 508 scores = scores.to(data["boxes"].device) 509 keep_by_nms = batched_nms( 510 data["boxes"].float(), 511 scores, 512 torch.zeros_like(data["boxes"][:, 0]), # categories 513 iou_threshold=crop_nms_thresh, 514 ) 515 data.filter(keep_by_nms) 516 517 data.to_numpy() 518 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 519 if output_mode == "instance_segmentation": 520 shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size 521 masks = util.mask_data_to_segmentation( 522 masks, shape=shape, with_background=with_background, merge_exclusively=False 523 ) 524 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. By default, set to '0.88'.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. By default, set to '0.7'.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. By default, set to '0.7'.
- min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
- 'rle': Return a list of dictionaries with run-length encoded masks. By default, set to 'instance_segmentation'.
- with_background: Whether to remove the largest object, which often covers the background.
Returns:
The segmentation masks.
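A sketch of how the two main output modes differ, assuming an initialized generator `amg` as in the sketch above. With 'binary_mask' each object is returned as a dictionary with keys such as 'segmentation', 'area', 'bbox', 'predicted_iou' and 'stability_score', while the default mode returns a single label image.

```python
# One dictionary per object, similar to the original Segment Anything output.
masks = amg.generate(pred_iou_thresh=0.88, output_mode="binary_mask")
for mask in masks:
    binary = mask["segmentation"]  # binary mask array for this object
    print(mask["area"], mask["bbox"], mask["predicted_iou"], mask["stability_score"])

# The default mode merges all masks into a single instance segmentation array.
instances = amg.generate(pred_iou_thresh=0.88)  # output_mode="instance_segmentation"
print(instances.shape, instances.dtype)
```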
558class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 559 """Generates an instance segmentation without prompts, using a point grid. 560 561 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 562 563 Args: 564 predictor: The Segment Anything predictor. 565 points_per_side: The number of points to be sampled along one side of the image. 566 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 567 points_per_batch: The number of points run simultaneously by the model. 568 Higher numbers may be faster but use more GPU memory. By default, set to '64'. 569 point_grids: A list over explicit grids of points used for sampling masks. 570 Normalized to [0, 1] with respect to the image coordinate system. 571 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 572 By default, set to '1.0'. 573 """ 574 575 # We only expose the arguments that make sense for the tiled mask generator. 576 # Anything related to crops doesn't make sense, because we re-use that functionality 577 # for tiling, so these parameters wouldn't have any effect. 578 def __init__( 579 self, 580 predictor: SamPredictor, 581 points_per_side: Optional[int] = 32, 582 points_per_batch: int = 64, 583 point_grids: Optional[List[np.ndarray]] = None, 584 stability_score_offset: float = 1.0, 585 ) -> None: 586 super().__init__( 587 predictor=predictor, 588 points_per_side=points_per_side, 589 points_per_batch=points_per_batch, 590 point_grids=point_grids, 591 stability_score_offset=stability_score_offset, 592 ) 593 594 @torch.no_grad() 595 def initialize( 596 self, 597 image: np.ndarray, 598 image_embeddings: Optional[util.ImageEmbeddings] = None, 599 i: Optional[int] = None, 600 tile_shape: Optional[Tuple[int, int]] = None, 601 halo: Optional[Tuple[int, int]] = None, 602 verbose: bool = False, 603 pbar_init: Optional[callable] = None, 604 pbar_update: Optional[callable] = None, 605 batch_size: int = 1, 606 mask: Optional[np.typing.ArrayLike] = None, 607 ) -> None: 608 """Initialize image embeddings and masks for an image. 609 610 Args: 611 image: The input image, volume or timeseries. 612 image_embeddings: Optional precomputed image embeddings. 613 See `util.precompute_image_embeddings` for details. 614 i: Index for the image data. Required if `image` has three spatial dimensions 615 or a time dimension and two spatial dimensions. 616 tile_shape: The tile shape for embedding prediction. 617 halo: The overlap of between tiles. 618 verbose: Whether to print computation progress. By default, set to 'False'. 619 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 620 Can be used together with pbar_update to handle napari progress bar in other thread. 621 To enables using this function within a threadworker. 622 pbar_update: Callback to update an external progress bar. 623 batch_size: The batch size for image embedding prediction. By default, set to '1'. 624 mask: An optional mask to define areas that are ignored in the segmentation. 
625 """ 626 original_size = image.shape[:2] 627 self._original_size = original_size 628 629 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 630 self._predictor, image, image_embeddings, tile_shape, halo, 631 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 632 ) 633 634 tiling = blocking([0, 0], original_size, tile_shape) 635 if tiles_in_mask is None: 636 n_tiles = tiling.numberOfBlocks 637 tile_ids = range(n_tiles) 638 else: 639 n_tiles = len(tiles_in_mask) 640 tile_ids = tiles_in_mask 641 642 # The crop box is always the full local tile. 643 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in tile_ids] 644 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 645 646 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 647 pbar_init(n_tiles, "Compute masks for tile") 648 649 # We need to cast to the image representation that is compatible with SAM. 650 image = util._to_image(image) 651 652 mask_data = [] 653 for idx, tile_id in enumerate(tile_ids): 654 # set the pre-computed embeddings for this tile 655 features = image_embeddings["features"][str(tile_id)] 656 tile_embeddings = { 657 "features": features, 658 "input_size": features.attrs["input_size"], 659 "original_size": features.attrs["original_size"], 660 } 661 util.set_precomputed(self._predictor, tile_embeddings, i) 662 663 # Compute the mask data for this tile and append it 664 this_mask_data = self._process_crop( 665 image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True 666 ) 667 mask_data.append(this_mask_data) 668 pbar_update(1) 669 pbar_close() 670 671 # set the initialized data 672 self._is_initialized = True 673 self._crop_list = mask_data 674 self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.
Arguments:
- predictor: The Segment Anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, set to '64'.
- point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
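Construction works like for `AutomaticMaskGenerator`; a short sketch, assuming a SAM `predictor` is available:

```python
# Crop-related parameters are not exposed, because tiling re-uses the crop functionality.
tiled_amg = TiledAutomaticMaskGenerator(predictor, points_per_side=32, points_per_batch=64)
```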
578 def __init__( 579 self, 580 predictor: SamPredictor, 581 points_per_side: Optional[int] = 32, 582 points_per_batch: int = 64, 583 point_grids: Optional[List[np.ndarray]] = None, 584 stability_score_offset: float = 1.0, 585 ) -> None: 586 super().__init__( 587 predictor=predictor, 588 points_per_side=points_per_side, 589 points_per_batch=points_per_batch, 590 point_grids=point_grids, 591 stability_score_offset=stability_score_offset, 592 )
594 @torch.no_grad() 595 def initialize( 596 self, 597 image: np.ndarray, 598 image_embeddings: Optional[util.ImageEmbeddings] = None, 599 i: Optional[int] = None, 600 tile_shape: Optional[Tuple[int, int]] = None, 601 halo: Optional[Tuple[int, int]] = None, 602 verbose: bool = False, 603 pbar_init: Optional[callable] = None, 604 pbar_update: Optional[callable] = None, 605 batch_size: int = 1, 606 mask: Optional[np.typing.ArrayLike] = None, 607 ) -> None: 608 """Initialize image embeddings and masks for an image. 609 610 Args: 611 image: The input image, volume or timeseries. 612 image_embeddings: Optional precomputed image embeddings. 613 See `util.precompute_image_embeddings` for details. 614 i: Index for the image data. Required if `image` has three spatial dimensions 615 or a time dimension and two spatial dimensions. 616 tile_shape: The tile shape for embedding prediction. 617 halo: The overlap of between tiles. 618 verbose: Whether to print computation progress. By default, set to 'False'. 619 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 620 Can be used together with pbar_update to handle napari progress bar in other thread. 621 To enables using this function within a threadworker. 622 pbar_update: Callback to update an external progress bar. 623 batch_size: The batch size for image embedding prediction. By default, set to '1'. 624 mask: An optional mask to define areas that are ignored in the segmentation. 625 """ 626 original_size = image.shape[:2] 627 self._original_size = original_size 628 629 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 630 self._predictor, image, image_embeddings, tile_shape, halo, 631 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 632 ) 633 634 tiling = blocking([0, 0], original_size, tile_shape) 635 if tiles_in_mask is None: 636 n_tiles = tiling.numberOfBlocks 637 tile_ids = range(n_tiles) 638 else: 639 n_tiles = len(tiles_in_mask) 640 tile_ids = tiles_in_mask 641 642 # The crop box is always the full local tile. 643 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in tile_ids] 644 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 645 646 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 647 pbar_init(n_tiles, "Compute masks for tile") 648 649 # We need to cast to the image representation that is compatible with SAM. 650 image = util._to_image(image) 651 652 mask_data = [] 653 for idx, tile_id in enumerate(tile_ids): 654 # set the pre-computed embeddings for this tile 655 features = image_embeddings["features"][str(tile_id)] 656 tile_embeddings = { 657 "features": features, 658 "input_size": features.attrs["input_size"], 659 "original_size": features.attrs["original_size"], 660 } 661 util.set_precomputed(self._predictor, tile_embeddings, i) 662 663 # Compute the mask data for this tile and append it 664 this_mask_data = self._process_crop( 665 image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True 666 ) 667 mask_data.append(this_mask_data) 668 pbar_update(1) 669 pbar_close() 670 671 # set the initialized data 672 self._is_initialized = True 673 self._crop_list = mask_data 674 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding prediction. By default, set to '1'.
- mask: An optional mask to define areas that are ignored in the segmentation.
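A minimal tiled workflow, continuing from the construction sketch above; `tile_shape`, `halo`, `batch_size` and the image size are example values, not recommendations.

```python
import numpy as np

large_image = np.random.rand(2048, 2048).astype("float32")  # stand-in for a large image

tiled_amg.initialize(
    large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True, batch_size=2
)
instances = tiled_amg.generate(pred_iou_thresh=0.88)
```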
682class DecoderAdapter(torch.nn.Module): 683 """Adapter to contain the UNETR decoder in a single module. 684 685 To apply the decoder on top of pre-computed embeddings for the segmentation functionality. 686 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 687 """ 688 def __init__(self, unetr: torch.nn.Module): 689 super().__init__() 690 691 self.base = unetr.base 692 self.out_conv = unetr.out_conv 693 self.deconv_out = unetr.deconv_out 694 self.decoder_head = unetr.decoder_head 695 self.final_activation = unetr.final_activation 696 self.postprocess_masks = unetr.postprocess_masks 697 698 self.decoder = unetr.decoder 699 self.deconv1 = unetr.deconv1 700 self.deconv2 = unetr.deconv2 701 self.deconv3 = unetr.deconv3 702 self.deconv4 = unetr.deconv4 703 704 def _forward_impl(self, input_): 705 z12 = input_ 706 707 z9 = self.deconv1(z12) 708 z6 = self.deconv2(z9) 709 z3 = self.deconv3(z6) 710 z0 = self.deconv4(z3) 711 712 updated_from_encoder = [z9, z6, z3] 713 714 x = self.base(z12) 715 x = self.decoder(x, encoder_inputs=updated_from_encoder) 716 x = self.deconv_out(x) 717 718 x = torch.cat([x, z0], dim=1) 719 x = self.decoder_head(x) 720 721 x = self.out_conv(x) 722 if self.final_activation is not None: 723 x = self.final_activation(x) 724 return x 725 726 def forward(self, input_, input_shape, original_shape): 727 x = self._forward_impl(input_) 728 x = self.postprocess_masks(x, input_shape, original_shape) 729 return x
Adapter to contain the UNETR decoder in a single module.
It is used to apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
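A sketch of applying the adapter to precomputed embeddings, mirroring what `InstanceSegmentationWithDecoder.initialize` does internally; it assumes a SAM `predictor`, a `decoder` (a `DecoderAdapter`) and precomputed `image_embeddings` are available.

```python
import torch

from micro_sam import util

# Set the precomputed embeddings on the predictor, then run the decoder on top of them.
util.set_precomputed(predictor, image_embeddings)
with torch.no_grad():
    output = decoder(
        predictor.features,              # image embeddings from the SAM encoder
        tuple(predictor.input_size),     # shape after SAM's internal resizing
        tuple(predictor.original_size),  # original image shape
    ).cpu().numpy().squeeze(0)

# The three output channels: foreground, center distances and boundary distances.
foreground, center_distances, boundary_distances = output
```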
688 def __init__(self, unetr: torch.nn.Module): 689 super().__init__() 690 691 self.base = unetr.base 692 self.out_conv = unetr.out_conv 693 self.deconv_out = unetr.deconv_out 694 self.decoder_head = unetr.decoder_head 695 self.final_activation = unetr.final_activation 696 self.postprocess_masks = unetr.postprocess_masks 697 698 self.decoder = unetr.decoder 699 self.deconv1 = unetr.deconv1 700 self.deconv2 = unetr.deconv2 701 self.deconv3 = unetr.deconv3 702 self.deconv4 = unetr.deconv4
Initialize internal Module state, shared by both nn.Module and ScriptModule.
726 def forward(self, input_, input_shape, original_shape): 727 x = self._forward_impl(input_) 728 x = self.postprocess_masks(x, input_shape, original_shape) 729 return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
732def get_unetr( 733 image_encoder: torch.nn.Module, 734 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 735 device: Optional[Union[str, torch.device]] = None, 736 out_channels: int = 3, 737 flexible_load_checkpoint: bool = False, 738) -> torch.nn.Module: 739 """Get UNETR model for automatic instance segmentation. 740 741 Args: 742 image_encoder: The image encoder of the SAM model. 743 This is used as encoder by the UNETR too. 744 decoder_state: Optional decoder state to initialize the weights of the UNETR decoder. 745 device: The device. By default, automatically chooses the best available device. 746 out_channels: The number of output channels. By default, set to '3'. 747 flexible_load_checkpoint: Whether to allow reinitialization of parameters 748 which could not be found in the provided decoder state. By default, set to 'False'. 749 750 Returns: 751 The UNETR model. 752 """ 753 device = util.get_device(device) 754 755 if decoder_state is None: 756 use_conv_transpose = False # By default, we use interpolation for upsampling. 757 else: 758 # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions. 759 # NOTE: Explanation to the logic below - 760 # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers" 761 # submodules. This naming convention indicates that transposed convolutions are used, 762 # wrapped inside a custom block module. 763 # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling. 764 use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers")) 765 766 unetr = UNETR( 767 backbone="sam", 768 encoder=image_encoder, 769 out_channels=out_channels, 770 use_sam_stats=True, 771 final_activation="Sigmoid", 772 use_skip_connection=False, 773 resize_input=True, 774 use_conv_transpose=use_conv_transpose, 775 ) 776 777 if decoder_state is not None: 778 unetr_state_dict = unetr.state_dict() 779 for k, v in unetr_state_dict.items(): 780 if not k.startswith("encoder"): 781 if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found. 782 if k in decoder_state: # First check whether the key is available in the provided decoder state. 783 unetr_state_dict[k] = decoder_state[k] 784 else: # Otherwise, allow it to initialize it. 785 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 786 unetr_state_dict[k] = v 787 788 else: # Whether be strict on finding the parameter in the decoder state. 789 if k not in decoder_state: 790 raise RuntimeError(f"The parameters for '{k}' could not be found.") 791 unetr_state_dict[k] = decoder_state[k] 792 793 unetr.load_state_dict(unetr_state_dict) 794 795 unetr.to(device) 796 return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device. By default, automatically chooses the best available device.
- out_channels: The number of output channels. By default, set to '3'.
- flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state. By default, set to 'False'.
Returns:
The UNETR model.
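A sketch of building the UNETR on top of a SAM image encoder; 'vit_b' is only an example model type, and whether the loaded state contains a 'decoder_state' depends on the checkpoint (plain SAM checkpoints do not, in which case the decoder is randomly initialized, e.g. as a starting point for training).

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

# Load a SAM model and its state; 'vit_b' is just an example model type.
predictor, state = util.get_sam_model(model_type="vit_b", return_state=True)

unetr = get_unetr(
    image_encoder=predictor.model.image_encoder,
    decoder_state=state.get("decoder_state"),  # None -> randomly initialized decoder
    out_channels=3,
)
```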
799def get_decoder( 800 image_encoder: torch.nn.Module, 801 decoder_state: OrderedDict[str, torch.Tensor], 802 device: Optional[Union[str, torch.device]] = None, 803) -> DecoderAdapter: 804 """Get decoder to predict outputs for automatic instance segmentation 805 806 Args: 807 image_encoder: The image encoder of the SAM model. 808 decoder_state: State to initialize the weights of the UNETR decoder. 809 device: The device. By default, automatically chooses the best available device. 810 811 Returns: 812 The decoder for instance segmentation. 813 """ 814 unetr = get_unetr(image_encoder, decoder_state, device) 815 return DecoderAdapter(unetr)
Get the decoder to predict outputs for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device. By default, automatically chooses the best available device.
Returns:
The decoder for instance segmentation.
818def get_predictor_and_decoder( 819 model_type: str, 820 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 821 device: Optional[Union[str, torch.device]] = None, 822 peft_kwargs: Optional[Dict] = None, 823) -> Tuple[SamPredictor, DecoderAdapter]: 824 """Load the SAM model (predictor) and instance segmentation decoder. 825 826 This requires a checkpoint that contains the state for both predictor 827 and decoder. 828 829 Args: 830 model_type: The type of the image encoder used in the SAM model. 831 checkpoint_path: Path to the checkpoint from which to load the data. 832 device: The device. By default, automatically chooses the best available device. 833 peft_kwargs: Keyword arguments for the PEFT wrapper class. 834 835 Returns: 836 The SAM predictor. 837 The decoder for instance segmentation. 838 """ 839 device = util.get_device(device) 840 predictor, state = util.get_sam_model( 841 model_type=model_type, 842 checkpoint_path=checkpoint_path, 843 device=device, 844 return_state=True, 845 peft_kwargs=peft_kwargs, 846 ) 847 848 if "decoder_state" not in state: 849 raise ValueError( 850 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 851 ) 852 853 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 854 return predictor, decoder
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device. By default, automatically chooses the best available device.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:
The SAM predictor. The decoder for instance segmentation.
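A minimal usage sketch; 'vit_b_lm' is used here as an example name for a micro_sam model that ships with a decoder state and may need to be replaced by the model you actually use.

from micro_sam.instance_segmentation import get_predictor_and_decoder

# Loads the SAM predictor and the matching instance segmentation decoder from one checkpoint.
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")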
906class InstanceSegmentationWithDecoder: 907 """Generates an instance segmentation without prompts, using a decoder. 908 909 Implements the same interface as `AutomaticMaskGenerator`. 910 911 Use this class as follows: 912 ```python 913 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 914 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 915 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 916 ``` 917 918 Args: 919 predictor: The segment anything predictor. 920 decoder: The decoder to predict intermediate representations for instance segmentation. 921 """ 922 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 923 self._predictor = predictor 924 self._decoder = decoder 925 926 # The decoder outputs. 927 self._foreground = None 928 self._center_distances = None 929 self._boundary_distances = None 930 931 self._is_initialized = False 932 933 @property 934 def is_initialized(self): 935 """Whether the mask generator has already been initialized. 936 """ 937 return self._is_initialized 938 939 @torch.no_grad() 940 def initialize( 941 self, 942 image: np.ndarray, 943 image_embeddings: Optional[util.ImageEmbeddings] = None, 944 i: Optional[int] = None, 945 verbose: bool = False, 946 pbar_init: Optional[callable] = None, 947 pbar_update: Optional[callable] = None, 948 ndim: int = 2, 949 ) -> None: 950 """Initialize image embeddings and decoder predictions for an image. 951 952 Args: 953 image: The input image, volume or timeseries. 954 image_embeddings: Optional precomputed image embeddings. 955 See `util.precompute_image_embeddings` for details. 956 i: Index for the image data. Required if `image` has three spatial dimensions 957 or a time dimension and two spatial dimensions. 958 verbose: Whether to be verbose. By default, set to 'False'. 959 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 960 Can be used together with pbar_update to handle napari progress bar in other thread. 961 To enables using this function within a threadworker. 962 pbar_update: Callback to update an external progress bar. 963 ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2. 964 """ 965 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 966 pbar_init(1, "Initialize instance segmentation with decoder") 967 968 if image_embeddings is None: 969 image_embeddings = util.precompute_image_embeddings( 970 predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose 971 ) 972 973 # Get the image embeddings from the predictor. 974 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 975 embeddings = self._predictor.features 976 input_shape = tuple(self._predictor.input_size) 977 original_shape = tuple(self._predictor.original_size) 978 979 # Run prediction with the UNETR decoder. 980 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 981 assert output.shape[0] == 3, f"{output.shape}" 982 pbar_update(1) 983 pbar_close() 984 985 # Set the state. 986 self._foreground = output[0] 987 self._center_distances = output[1] 988 self._boundary_distances = output[2] 989 self._i = i 990 self._is_initialized = True 991 992 def _to_masks(self, segmentation, output_mode): 993 if output_mode != "binary_mask": 994 raise ValueError( 995 f"Output mode {output_mode} is not supported. 
Choose one of 'instance_segmentation', 'binary_masks'" 996 ) 997 998 props = regionprops(segmentation) 999 ndim = segmentation.ndim 1000 assert ndim in (2, 3) 1001 1002 shape = segmentation.shape 1003 if ndim == 2: 1004 crop_box = [0, shape[1], 0, shape[0]] 1005 else: 1006 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1007 1008 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1009 def to_bbox_2d(bbox): 1010 y0, x0 = bbox[0], bbox[1] 1011 w = bbox[3] - x0 1012 h = bbox[2] - y0 1013 return [x0, w, y0, h] 1014 1015 def to_bbox_3d(bbox): 1016 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1017 w = bbox[5] - x0 1018 h = bbox[4] - y0 1019 d = bbox[3] - y0 1020 return [x0, w, y0, h, z0, d] 1021 1022 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1023 masks = [ 1024 { 1025 "segmentation": segmentation == prop.label, 1026 "area": prop.area, 1027 "bbox": to_bbox(prop.bbox), 1028 "crop_box": crop_box, 1029 "seg_id": prop.label, 1030 } for prop in props 1031 ] 1032 return masks 1033 1034 def generate( 1035 self, 1036 center_distance_threshold: float = 0.5, 1037 boundary_distance_threshold: float = 0.5, 1038 foreground_threshold: float = 0.5, 1039 foreground_smoothing: float = 1.0, 1040 distance_smoothing: float = 1.6, 1041 min_size: int = 0, 1042 output_mode: str = "instance_segmentation", 1043 tile_shape: Optional[Tuple[int, int]] = None, 1044 halo: Optional[Tuple[int, int]] = None, 1045 n_threads: Optional[int] = None, 1046 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1047 """Generate instance segmentation for the currently initialized image. 1048 1049 Args: 1050 center_distance_threshold: Center distance predictions below this value will be 1051 used to find seeds (intersected with thresholded boundary distance predictions). 1052 By default, set to '0.5'. 1053 boundary_distance_threshold: Boundary distance predictions below this value will be 1054 used to find seeds (intersected with thresholded center distance predictions). 1055 By default, set to '0.5'. 1056 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1057 By default, set to '0.5'. 1058 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1059 checkerboard artifacts in the prediction. By default, set to '1.0'. 1060 distance_smoothing: Sigma value for smoothing the distance predictions. 1061 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1062 output_mode: The form masks are returned in. Possible values are: 1063 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1064 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1065 By default, set to 'instance_segmentation'. 1066 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1067 This parameter is independent from the tile shape for computing the embeddings. 1068 If not given then post-processing will not be parallelized. 1069 halo: Halo for parallel post-processing. See also `tile_shape`. 1070 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1071 1072 Returns: 1073 The segmentation masks. 1074 """ 1075 if not self.is_initialized: 1076 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1077 1078 if foreground_smoothing > 0: 1079 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1080 else: 1081 foreground = self._foreground 1082 1083 if tile_shape is None: 1084 segmentation = watershed_from_center_and_boundary_distances( 1085 center_distances=self._center_distances, 1086 boundary_distances=self._boundary_distances, 1087 foreground_map=foreground, 1088 center_distance_threshold=center_distance_threshold, 1089 boundary_distance_threshold=boundary_distance_threshold, 1090 foreground_threshold=foreground_threshold, 1091 distance_smoothing=distance_smoothing, 1092 min_size=min_size, 1093 ) 1094 else: 1095 if halo is None: 1096 raise ValueError("You must pass a value for halo if tile_shape is given.") 1097 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1098 center_distances=self._center_distances, 1099 boundary_distances=self._boundary_distances, 1100 foreground_map=foreground, 1101 center_distance_threshold=center_distance_threshold, 1102 boundary_distance_threshold=boundary_distance_threshold, 1103 foreground_threshold=foreground_threshold, 1104 distance_smoothing=distance_smoothing, 1105 min_size=min_size, 1106 tile_shape=tile_shape, 1107 halo=halo, 1108 n_threads=n_threads, 1109 verbose=False, 1110 ) 1111 1112 if output_mode != "instance_segmentation": 1113 segmentation = self._to_masks(segmentation, output_mode) 1114 return segmentation 1115 1116 def get_state(self) -> Dict[str, Any]: 1117 """Get the initialized state of the instance segmenter. 1118 1119 Returns: 1120 Instance segmentation state. 1121 """ 1122 if not self.is_initialized: 1123 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1124 1125 return { 1126 "foreground": self._foreground, 1127 "center_distances": self._center_distances, 1128 "boundary_distances": self._boundary_distances, 1129 } 1130 1131 def set_state(self, state: Dict[str, Any]) -> None: 1132 """Set the state of the instance segmenter. 1133 1134 Args: 1135 state: The instance segmentation state 1136 """ 1137 self._foreground = state["foreground"] 1138 self._center_distances = state["center_distances"] 1139 self._boundary_distances = state["boundary_distances"] 1140 self._is_initialized = True 1141 1142 def clear_state(self): 1143 """Clear the state of the instance segmenter. 1144 """ 1145 self._foreground = None 1146 self._center_distances = None 1147 self._boundary_distances = None 1148 self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as AutomaticMaskGenerator.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
922 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 923 self._predictor = predictor 924 self._decoder = decoder 925 926 # The decoder outputs. 927 self._foreground = None 928 self._center_distances = None 929 self._boundary_distances = None 930 931 self._is_initialized = False
933 @property 934 def is_initialized(self): 935 """Whether the mask generator has already been initialized. 936 """ 937 return self._is_initialized
Whether the mask generator has already been initialized.
939 @torch.no_grad() 940 def initialize( 941 self, 942 image: np.ndarray, 943 image_embeddings: Optional[util.ImageEmbeddings] = None, 944 i: Optional[int] = None, 945 verbose: bool = False, 946 pbar_init: Optional[callable] = None, 947 pbar_update: Optional[callable] = None, 948 ndim: int = 2, 949 ) -> None: 950 """Initialize image embeddings and decoder predictions for an image. 951 952 Args: 953 image: The input image, volume or timeseries. 954 image_embeddings: Optional precomputed image embeddings. 955 See `util.precompute_image_embeddings` for details. 956 i: Index for the image data. Required if `image` has three spatial dimensions 957 or a time dimension and two spatial dimensions. 958 verbose: Whether to be verbose. By default, set to 'False'. 959 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 960 Can be used together with pbar_update to handle napari progress bar in other thread. 961 To enables using this function within a threadworker. 962 pbar_update: Callback to update an external progress bar. 963 ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2. 964 """ 965 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 966 pbar_init(1, "Initialize instance segmentation with decoder") 967 968 if image_embeddings is None: 969 image_embeddings = util.precompute_image_embeddings( 970 predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose 971 ) 972 973 # Get the image embeddings from the predictor. 974 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 975 embeddings = self._predictor.features 976 input_shape = tuple(self._predictor.input_size) 977 original_shape = tuple(self._predictor.original_size) 978 979 # Run prediction with the UNETR decoder. 980 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 981 assert output.shape[0] == 3, f"{output.shape}" 982 pbar_update(1) 983 pbar_close() 984 985 # Set the state. 986 self._foreground = output[0] 987 self._center_distances = output[1] 988 self._boundary_distances = output[2] 989 self._i = i 990 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
- ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2.
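A sketch of initializing with precomputed embeddings; the save path 'embeddings.zarr' and the variables image and segmenter are illustrative assumptions.

from micro_sam import util

# Precompute (and optionally cache) the image embeddings, then initialize the segmenter with them.
image_embeddings = util.precompute_image_embeddings(predictor, image, save_path="embeddings.zarr", ndim=2)
segmenter.initialize(image, image_embeddings=image_embeddings)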
1034 def generate( 1035 self, 1036 center_distance_threshold: float = 0.5, 1037 boundary_distance_threshold: float = 0.5, 1038 foreground_threshold: float = 0.5, 1039 foreground_smoothing: float = 1.0, 1040 distance_smoothing: float = 1.6, 1041 min_size: int = 0, 1042 output_mode: str = "instance_segmentation", 1043 tile_shape: Optional[Tuple[int, int]] = None, 1044 halo: Optional[Tuple[int, int]] = None, 1045 n_threads: Optional[int] = None, 1046 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1047 """Generate instance segmentation for the currently initialized image. 1048 1049 Args: 1050 center_distance_threshold: Center distance predictions below this value will be 1051 used to find seeds (intersected with thresholded boundary distance predictions). 1052 By default, set to '0.5'. 1053 boundary_distance_threshold: Boundary distance predictions below this value will be 1054 used to find seeds (intersected with thresholded center distance predictions). 1055 By default, set to '0.5'. 1056 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1057 By default, set to '0.5'. 1058 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1059 checkerboard artifacts in the prediction. By default, set to '1.0'. 1060 distance_smoothing: Sigma value for smoothing the distance predictions. 1061 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1062 output_mode: The form masks are returned in. Possible values are: 1063 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1064 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1065 By default, set to 'instance_segmentation'. 1066 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1067 This parameter is independent from the tile shape for computing the embeddings. 1068 If not given then post-processing will not be parallelized. 1069 halo: Halo for parallel post-processing. See also `tile_shape`. 1070 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1071 1072 Returns: 1073 The segmentation masks. 1074 """ 1075 if not self.is_initialized: 1076 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1077 1078 if foreground_smoothing > 0: 1079 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1080 else: 1081 foreground = self._foreground 1082 1083 if tile_shape is None: 1084 segmentation = watershed_from_center_and_boundary_distances( 1085 center_distances=self._center_distances, 1086 boundary_distances=self._boundary_distances, 1087 foreground_map=foreground, 1088 center_distance_threshold=center_distance_threshold, 1089 boundary_distance_threshold=boundary_distance_threshold, 1090 foreground_threshold=foreground_threshold, 1091 distance_smoothing=distance_smoothing, 1092 min_size=min_size, 1093 ) 1094 else: 1095 if halo is None: 1096 raise ValueError("You must pass a value for halo if tile_shape is given.") 1097 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1098 center_distances=self._center_distances, 1099 boundary_distances=self._boundary_distances, 1100 foreground_map=foreground, 1101 center_distance_threshold=center_distance_threshold, 1102 boundary_distance_threshold=boundary_distance_threshold, 1103 foreground_threshold=foreground_threshold, 1104 distance_smoothing=distance_smoothing, 1105 min_size=min_size, 1106 tile_shape=tile_shape, 1107 halo=halo, 1108 n_threads=n_threads, 1109 verbose=False, 1110 ) 1111 1112 if output_mode != "instance_segmentation": 1113 segmentation = self._to_masks(segmentation, output_mode) 1114 return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions). By default, set to '0.5'.
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions). By default, set to '0.5'.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask. By default, set to '0.5'.
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction. By default, set to '1.0'.
- distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
- min_size: Minimal object size in the segmentation result. By default, set to '0'.
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
- tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
- halo: Halo for parallel post-processing. See also tile_shape.
- n_threads: Number of threads for parallel post-processing. See also tile_shape.
Returns:
The segmentation masks.
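A sketch of generating the segmentation with parallelized post-processing; the tile shape, halo, and threshold values are illustrative.

# Instance segmentation as a single label array, with post-processing parallelized over tiles.
segmentation = segmenter.generate(
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.5,
    min_size=50,
    tile_shape=(1024, 1024),
    halo=(128, 128),
)
# Alternatively, return a list of per-object dictionaries with binary masks.
masks = segmenter.generate(output_mode="binary_mask")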
1116 def get_state(self) -> Dict[str, Any]: 1117 """Get the initialized state of the instance segmenter. 1118 1119 Returns: 1120 Instance segmentation state. 1121 """ 1122 if not self.is_initialized: 1123 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1124 1125 return { 1126 "foreground": self._foreground, 1127 "center_distances": self._center_distances, 1128 "boundary_distances": self._boundary_distances, 1129 }
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
1131 def set_state(self, state: Dict[str, Any]) -> None: 1132 """Set the state of the instance segmenter. 1133 1134 Args: 1135 state: The instance segmentation state 1136 """ 1137 self._foreground = state["foreground"] 1138 self._center_distances = state["center_distances"] 1139 self._boundary_distances = state["boundary_distances"] 1140 self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state.
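A sketch of re-using the computed state, for example to try different thresholds without recomputing embeddings and decoder predictions; predictor and decoder are assumed to be loaded as shown above.

# Cache the decoder predictions ...
state = segmenter.get_state()
# ... and restore them in a new segmenter instance.
other_segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
other_segmenter.set_state(state)
segmentation = other_segmenter.generate(center_distance_threshold=0.4)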
1151class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1152 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1153 """ 1154 1155 # Apply the decoder in a batched fashion, and then perform the resizing independently per output. 1156 # This is necessary, because the individual tiles may have different tile shapes due to border tiles. 1157 def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes): 1158 batched_embeddings = torch.cat(batched_embeddings) 1159 output = self._decoder._forward_impl(batched_embeddings) 1160 1161 batched_output = [] 1162 for x, input_shape, original_shape in zip(output, input_shapes, original_shapes): 1163 x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0) 1164 batched_output.append(x.cpu().numpy()) 1165 return batched_output 1166 1167 @torch.no_grad() 1168 def initialize( 1169 self, 1170 image: np.ndarray, 1171 image_embeddings: Optional[util.ImageEmbeddings] = None, 1172 i: Optional[int] = None, 1173 tile_shape: Optional[Tuple[int, int]] = None, 1174 halo: Optional[Tuple[int, int]] = None, 1175 verbose: bool = False, 1176 pbar_init: Optional[callable] = None, 1177 pbar_update: Optional[callable] = None, 1178 batch_size: int = 1, 1179 mask: Optional[np.typing.ArrayLike] = None, 1180 ) -> None: 1181 """Initialize image embeddings and decoder predictions for an image. 1182 1183 Args: 1184 image: The input image, volume or timeseries. 1185 image_embeddings: Optional precomputed image embeddings. 1186 See `util.precompute_image_embeddings` for details. 1187 i: Index for the image data. Required if `image` has three spatial dimensions 1188 or a time dimension and two spatial dimensions. 1189 tile_shape: Shape of the tiles for precomputing image embeddings. 1190 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1191 verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'. 1192 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1193 Can be used together with pbar_update to handle napari progress bar in other thread. 1194 To enables using this function within a threadworker. 1195 pbar_update: Callback to update an external progress bar. 1196 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 1197 mask: An optional mask to define areas that are ignored in the segmentation. 
1198 """ 1199 original_size = image.shape[:2] 1200 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 1201 self._predictor, image, image_embeddings, tile_shape, halo, 1202 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 1203 ) 1204 tiling = blocking([0, 0], original_size, tile_shape) 1205 1206 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1207 1208 foreground = np.zeros(original_size, dtype="float32") 1209 center_distances = np.zeros(original_size, dtype="float32") 1210 boundary_distances = np.zeros(original_size, dtype="float32") 1211 1212 msg = "Initialize tiled instance segmentation with decoder" 1213 if tiles_in_mask is None: 1214 n_tiles = tiling.numberOfBlocks 1215 all_tile_ids = list(range(n_tiles)) 1216 else: 1217 n_tiles = len(tiles_in_mask) 1218 all_tile_ids = tiles_in_mask 1219 msg += " and mask" 1220 1221 n_batches = int(np.ceil(n_tiles / batch_size)) 1222 pbar_init(n_tiles, msg) 1223 tile_ids_for_batches = np.array_split(all_tile_ids, n_batches) 1224 1225 for tile_ids in tile_ids_for_batches: 1226 batched_embeddings, input_shapes, original_shapes = [], [], [] 1227 for tile_id in tile_ids: 1228 # Get the image embeddings from the predictor for this tile. 1229 self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id) 1230 1231 batched_embeddings.append(self._predictor.features) 1232 input_shapes.append(tuple(self._predictor.input_size)) 1233 original_shapes.append(tuple(self._predictor.original_size)) 1234 1235 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1236 1237 for output_id, tile_id in enumerate(tile_ids): 1238 output = batched_output[output_id] 1239 assert output.shape[0] == 3 1240 1241 # Set the predictions in the output for this tile. 1242 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1243 local_bb = tuple( 1244 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1245 ) 1246 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1247 1248 foreground[inner_bb] = output[0][local_bb] 1249 center_distances[inner_bb] = output[1][local_bb] 1250 boundary_distances[inner_bb] = output[2][local_bb] 1251 pbar_update(1) 1252 1253 pbar_close() 1254 1255 # Set the state. 1256 self._i = i 1257 self._foreground = foreground 1258 self._center_distances = center_distances 1259 self._boundary_distances = boundary_distances 1260 self._is_initialized = True
Same as InstanceSegmentationWithDecoder but for tiled image embeddings.
1167 @torch.no_grad() 1168 def initialize( 1169 self, 1170 image: np.ndarray, 1171 image_embeddings: Optional[util.ImageEmbeddings] = None, 1172 i: Optional[int] = None, 1173 tile_shape: Optional[Tuple[int, int]] = None, 1174 halo: Optional[Tuple[int, int]] = None, 1175 verbose: bool = False, 1176 pbar_init: Optional[callable] = None, 1177 pbar_update: Optional[callable] = None, 1178 batch_size: int = 1, 1179 mask: Optional[np.typing.ArrayLike] = None, 1180 ) -> None: 1181 """Initialize image embeddings and decoder predictions for an image. 1182 1183 Args: 1184 image: The input image, volume or timeseries. 1185 image_embeddings: Optional precomputed image embeddings. 1186 See `util.precompute_image_embeddings` for details. 1187 i: Index for the image data. Required if `image` has three spatial dimensions 1188 or a time dimension and two spatial dimensions. 1189 tile_shape: Shape of the tiles for precomputing image embeddings. 1190 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1191 verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'. 1192 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1193 Can be used together with pbar_update to handle napari progress bar in other thread. 1194 To enables using this function within a threadworker. 1195 pbar_update: Callback to update an external progress bar. 1196 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 1197 mask: An optional mask to define areas that are ignored in the segmentation. 1198 """ 1199 original_size = image.shape[:2] 1200 self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings( 1201 self._predictor, image, image_embeddings, tile_shape, halo, 1202 verbose=verbose, batch_size=batch_size, mask=mask, i=i, 1203 ) 1204 tiling = blocking([0, 0], original_size, tile_shape) 1205 1206 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1207 1208 foreground = np.zeros(original_size, dtype="float32") 1209 center_distances = np.zeros(original_size, dtype="float32") 1210 boundary_distances = np.zeros(original_size, dtype="float32") 1211 1212 msg = "Initialize tiled instance segmentation with decoder" 1213 if tiles_in_mask is None: 1214 n_tiles = tiling.numberOfBlocks 1215 all_tile_ids = list(range(n_tiles)) 1216 else: 1217 n_tiles = len(tiles_in_mask) 1218 all_tile_ids = tiles_in_mask 1219 msg += " and mask" 1220 1221 n_batches = int(np.ceil(n_tiles / batch_size)) 1222 pbar_init(n_tiles, msg) 1223 tile_ids_for_batches = np.array_split(all_tile_ids, n_batches) 1224 1225 for tile_ids in tile_ids_for_batches: 1226 batched_embeddings, input_shapes, original_shapes = [], [], [] 1227 for tile_id in tile_ids: 1228 # Get the image embeddings from the predictor for this tile. 1229 self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id) 1230 1231 batched_embeddings.append(self._predictor.features) 1232 input_shapes.append(tuple(self._predictor.input_size)) 1233 original_shapes.append(tuple(self._predictor.original_size)) 1234 1235 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1236 1237 for output_id, tile_id in enumerate(tile_ids): 1238 output = batched_output[output_id] 1239 assert output.shape[0] == 3 1240 1241 # Set the predictions in the output for this tile. 
1242 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1243 local_bb = tuple( 1244 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1245 ) 1246 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1247 1248 foreground[inner_bb] = output[0][local_bb] 1249 center_distances[inner_bb] = output[1][local_bb] 1250 boundary_distances[inner_bb] = output[2][local_bb] 1251 pbar_update(1) 1252 1253 pbar_close() 1254 1255 # Set the state. 1256 self._i = i 1257 self._foreground = foreground 1258 self._center_distances = center_distances 1259 self._boundary_distances = boundary_distances 1260 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: Shape of the tiles for precomputing image embeddings.
- halo: Overlap of the tiles for tiled precomputation of image embeddings.
- verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding computation and segmentation decoder prediction.
- mask: An optional mask to define areas that are ignored in the segmentation.
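A sketch for tiled initialization of a large image; the tile shape and halo values are illustrative and should be adapted to the image size and the available memory.

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
# Compute tiled embeddings and stitch the decoder predictions back into full-size maps.
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(256, 256), batch_size=4)
segmentation = segmenter.generate()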
1335class AutomaticPromptGenerator(InstanceSegmentationWithDecoder): 1336 """Generates an instance segmentation automatically, using automatically generated prompts from a decoder. 1337 1338 This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator` 1339 1340 Args: 1341 predictor: The segment anything predictor. 1342 decoder: The derive prompts for automatic instance segmentation. 1343 """ 1344 def generate( 1345 self, 1346 min_size: int = 25, 1347 center_distance_threshold: float = 0.5, 1348 boundary_distance_threshold: float = 0.5, 1349 foreground_threshold: float = 0.5, 1350 multimasking: bool = False, 1351 batch_size: int = 32, 1352 nms_threshold: float = 0.9, 1353 intersection_over_min: bool = False, 1354 output_mode: str = "instance_segmentation", 1355 mask_threshold: Optional[Union[float, str]] = None, 1356 refine_with_box_prompts: bool = False, 1357 prompt_function: Optional[callable] = None, 1358 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1359 """Generate the instance segmentation for the currently initialized image. 1360 1361 The instance segmentation is generated by deriving prompts from the foreground and 1362 distance predictions of the segmentation decoder by thresholding these predictions, 1363 intersecting them, computing connected components, and then using the component's 1364 centers as point prompts. The masks are then filtered via NMS and merged into a segmentation. 1365 1366 Args: 1367 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1368 center_distance_threshold: The threshold for the center distance predictions. 1369 boundary_distance_threshold: The threshold for the boundary distance predictions. 1370 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1371 batch_size: The batch size for parallelizing the prediction based on prompts. 1372 nms_threshold: The threshold for non-maximum suppression (NMS). 1373 intersection_over_min: Whether to use the minimum area of the two objects or the 1374 intersection over union area (default) in NMS. 1375 output_mode: The form masks are returned in. Possible values are: 1376 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1377 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1378 By default, set to 'instance_segmentation'. 1379 mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.` 1380 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1381 derived from the segmentations after point prompts. 1382 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1383 If given, the default prompt derivation procedure is not used. Must have the following signature: 1384 ``` 1385 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1386 ``` 1387 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1388 predictions from the segmentation decoder. It must returns a dictionary containing 1389 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1390 1391 Returns: 1392 The instance segmentation masks. 1393 """ 1394 if not self.is_initialized: 1395 raise RuntimeError("AutomaticPromptGenerator has not been initialized. 
Call initialize first.") 1396 foreground, center_distances, boundary_distances =\ 1397 self._foreground, self._center_distances, self._boundary_distances 1398 1399 # 1.) Derive promtps from the decoder predictions. 1400 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1401 prompts = prompt_function( 1402 foreground=foreground, 1403 center_distances=center_distances, 1404 boundary_distances=boundary_distances, 1405 foreground_threshold=foreground_threshold, 1406 center_distance_threshold=center_distance_threshold, 1407 boundary_distance_threshold=boundary_distance_threshold, 1408 ) 1409 1410 # 2.) Apply the predictor to the prompts. 1411 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1412 return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_egmentation" else [] 1413 else: 1414 predictions = batched_inference( 1415 self._predictor, 1416 image=None, 1417 batch_size=batch_size, 1418 return_instance_segmentation=False, 1419 multimasking=multimasking, 1420 mask_threshold=mask_threshold, 1421 i=getattr(self, "_i", None), 1422 **prompts, 1423 ) 1424 1425 # 3.) Refine the segmentation with box prompts. 1426 if refine_with_box_prompts: 1427 box_extension = 0.01 # expose as hyperparam? 1428 prompts = _derive_box_prompts(predictions, box_extension) 1429 predictions = batched_inference( 1430 self._predictor, 1431 image=None, 1432 batch_size=batch_size, 1433 return_instance_segmentation=False, 1434 multimasking=multimasking, 1435 mask_threshold=mask_threshold, 1436 i=getattr(self, "_i", None), 1437 **prompts, 1438 ) 1439 1440 # 4.) Apply non-max suppression to the masks. 1441 segmentation = util.apply_nms( 1442 predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1443 ) 1444 if output_mode != "instance_segmentation": 1445 segmentation = self._to_masks(segmentation, output_mode) 1446 return segmentation
Generates an instance segmentation automatically, using automatically generated prompts from a decoder.
This class is used in the same way as InstanceSegmentationWithDecoder and AutomaticMaskGenerator.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder used to derive prompts for automatic instance segmentation.
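A minimal usage sketch, analogous to InstanceSegmentationWithDecoder; predictor and decoder are assumed to be loaded, e.g. via get_predictor_and_decoder.

segmenter = AutomaticPromptGenerator(predictor, decoder)
segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
segmentation = segmenter.generate()  # Derive prompts, run SAM on them, and merge the masks.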
1344 def generate( 1345 self, 1346 min_size: int = 25, 1347 center_distance_threshold: float = 0.5, 1348 boundary_distance_threshold: float = 0.5, 1349 foreground_threshold: float = 0.5, 1350 multimasking: bool = False, 1351 batch_size: int = 32, 1352 nms_threshold: float = 0.9, 1353 intersection_over_min: bool = False, 1354 output_mode: str = "instance_segmentation", 1355 mask_threshold: Optional[Union[float, str]] = None, 1356 refine_with_box_prompts: bool = False, 1357 prompt_function: Optional[callable] = None, 1358 ) -> Union[List[Dict[str, Any]], np.ndarray]: 1359 """Generate the instance segmentation for the currently initialized image. 1360 1361 The instance segmentation is generated by deriving prompts from the foreground and 1362 distance predictions of the segmentation decoder by thresholding these predictions, 1363 intersecting them, computing connected components, and then using the component's 1364 centers as point prompts. The masks are then filtered via NMS and merged into a segmentation. 1365 1366 Args: 1367 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1368 center_distance_threshold: The threshold for the center distance predictions. 1369 boundary_distance_threshold: The threshold for the boundary distance predictions. 1370 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1371 batch_size: The batch size for parallelizing the prediction based on prompts. 1372 nms_threshold: The threshold for non-maximum suppression (NMS). 1373 intersection_over_min: Whether to use the minimum area of the two objects or the 1374 intersection over union area (default) in NMS. 1375 output_mode: The form masks are returned in. Possible values are: 1376 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1377 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1378 By default, set to 'instance_segmentation'. 1379 mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.` 1380 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1381 derived from the segmentations after point prompts. 1382 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1383 If given, the default prompt derivation procedure is not used. Must have the following signature: 1384 ``` 1385 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1386 ``` 1387 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1388 predictions from the segmentation decoder. It must returns a dictionary containing 1389 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1390 1391 Returns: 1392 The instance segmentation masks. 1393 """ 1394 if not self.is_initialized: 1395 raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.") 1396 foreground, center_distances, boundary_distances =\ 1397 self._foreground, self._center_distances, self._boundary_distances 1398 1399 # 1.) Derive promtps from the decoder predictions. 
1400 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1401 prompts = prompt_function( 1402 foreground=foreground, 1403 center_distances=center_distances, 1404 boundary_distances=boundary_distances, 1405 foreground_threshold=foreground_threshold, 1406 center_distance_threshold=center_distance_threshold, 1407 boundary_distance_threshold=boundary_distance_threshold, 1408 ) 1409 1410 # 2.) Apply the predictor to the prompts. 1411 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1412 return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_egmentation" else [] 1413 else: 1414 predictions = batched_inference( 1415 self._predictor, 1416 image=None, 1417 batch_size=batch_size, 1418 return_instance_segmentation=False, 1419 multimasking=multimasking, 1420 mask_threshold=mask_threshold, 1421 i=getattr(self, "_i", None), 1422 **prompts, 1423 ) 1424 1425 # 3.) Refine the segmentation with box prompts. 1426 if refine_with_box_prompts: 1427 box_extension = 0.01 # expose as hyperparam? 1428 prompts = _derive_box_prompts(predictions, box_extension) 1429 predictions = batched_inference( 1430 self._predictor, 1431 image=None, 1432 batch_size=batch_size, 1433 return_instance_segmentation=False, 1434 multimasking=multimasking, 1435 mask_threshold=mask_threshold, 1436 i=getattr(self, "_i", None), 1437 **prompts, 1438 ) 1439 1440 # 4.) Apply non-max suppression to the masks. 1441 segmentation = util.apply_nms( 1442 predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1443 ) 1444 if output_mode != "instance_segmentation": 1445 segmentation = self._to_masks(segmentation, output_mode) 1446 return segmentation
Generate the instance segmentation for the currently initialized image.
The instance segmentation is generated by deriving prompts from the foreground and distance predictions of the segmentation decoder by thresholding these predictions, intersecting them, computing connected components, and then using the component's centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
Arguments:
- min_size: Minimal object size in the segmentation result. By default, set to '25'.
- center_distance_threshold: The threshold for the center distance predictions.
- boundary_distance_threshold: The threshold for the boundary distance predictions.
- multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
- batch_size: The batch size for parallelizing the prediction based on prompts.
- nms_threshold: The threshold for non-maximum suppression (NMS).
- intersection_over_min: Whether to compute the overlap in NMS as intersection over the minimum area of the two objects, instead of intersection over union (the default).
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
- mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
- refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts.
- prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
If given, the default prompt derivation procedure is not used. Must have the following signature:
def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.
Returns:
The instance segmentation masks.
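A sketch of a custom prompt_function that derives one point prompt per connected seed component; the returned keyword names ('points' and 'point_labels') are an assumption about the prompt format accepted by micro_sam.inference.batched_inference, and the thresholds are illustrative.

import numpy as np
from skimage.measure import label, regionprops

def prompt_function(foreground, center_distances, boundary_distances, **kwargs):
    # Seeds: foreground regions with small center distance.
    seeds = label(np.logical_and(foreground > 0.5, center_distances < 0.5))
    centers = np.array([prop.centroid for prop in regionprops(seeds)])
    if len(centers) == 0:  # No prompts could be derived.
        return None
    # One positive point prompt per object, in (x, y) order as expected by SAM (assumption).
    points = centers[:, ::-1].reshape(-1, 1, 2)
    point_labels = np.ones((len(points), 1), dtype="int32")
    return {"points": points, "point_labels": point_labels}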
1449class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder): 1450 """Same as `AutomaticPromptGenerator` but for tiled image embeddings. 1451 """ 1452 def generate( 1453 self, 1454 min_size: int = 25, 1455 center_distance_threshold: float = 0.5, 1456 boundary_distance_threshold: float = 0.5, 1457 foreground_threshold: float = 0.5, 1458 multimasking: bool = False, 1459 batch_size: int = 32, 1460 nms_threshold: float = 0.9, 1461 intersection_over_min: bool = False, 1462 output_mode: str = "instance_segmentation", 1463 mask_threshold: Optional[Union[float, str]] = None, 1464 refine_with_box_prompts: bool = False, 1465 prompt_function: Optional[callable] = None, 1466 optimize_memory: bool = False, 1467 ) -> List[Dict[str, Any]]: 1468 """Generate tiling-based instance segmentation for the currently initialized image. 1469 1470 Args: 1471 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1472 center_distance_threshold: The threshold for the center distance predictions. 1473 boundary_distance_threshold: The threshold for the boundary distance predictions. 1474 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1475 batch_size: The batch size for parallelizing the prediction based on prompts. 1476 nms_threshold: The threshold for non-maximum suppression (NMS). 1477 intersection_over_min: Whether to use the minimum area of the two objects or the 1478 intersection over union area (default) in NMS. 1479 output_mode: The form masks are returned in. Possible values are: 1480 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1481 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1482 By default, set to 'instance_segmentation'. 1483 mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.` 1484 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1485 derived from the segmentations after point prompts. Currently not supported for tiled segmentation. 1486 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1487 If given, the default prompt derivation procedure is not used. Must have the following signature: 1488 ``` 1489 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1490 ``` 1491 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1492 predictions from the segmentation decoder. It must returns a dictionary containing 1493 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1494 optimize_memory: Whether to optimize the memory consumption by merging the per-slice 1495 segmentation results immediatly with NMS, rather than running a NMS for all results. 1496 This may lead to a slightly different segmentation result and is only compatible with 1497 `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`. 1498 1499 Returns: 1500 The instance segmentation masks. 1501 """ 1502 if not self.is_initialized: 1503 raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. 
Call initialize first.") 1504 if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts): 1505 raise ValueError("Invalid settings") 1506 foreground, center_distances, boundary_distances =\ 1507 self._foreground, self._center_distances, self._boundary_distances 1508 1509 # 1.) Derive promtps from the decoder predictions. 1510 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1511 prompts = prompt_function( 1512 foreground, 1513 center_distances, 1514 boundary_distances, 1515 foreground_threshold=foreground_threshold, 1516 center_distance_threshold=center_distance_threshold, 1517 boundary_distance_threshold=boundary_distance_threshold, 1518 ) 1519 1520 # 2.) Apply the predictor to the prompts. 1521 shape = foreground.shape 1522 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1523 return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else [] 1524 else: 1525 if optimize_memory: 1526 prompts.update(dict( 1527 min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1528 )) 1529 predictions = batched_tiled_inference( 1530 self._predictor, 1531 image=None, 1532 batch_size=batch_size, 1533 image_embeddings=self._image_embeddings, 1534 return_instance_segmentation=False, 1535 multimasking=multimasking, 1536 optimize_memory=optimize_memory, 1537 i=getattr(self, "_i", None), 1538 **prompts 1539 ) 1540 # Optimize memory directly returns an instance segmentation and does not 1541 # allow for any further refinements. 1542 if optimize_memory: 1543 return predictions 1544 1545 # 3.) Refine the segmentation with box prompts. 1546 if refine_with_box_prompts: 1547 # TODO 1548 raise NotImplementedError 1549 1550 # 4.) Apply non-max suppression to the masks. 1551 segmentation = util.apply_nms( 1552 predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold, 1553 intersection_over_min=intersection_over_min, 1554 ) 1555 if output_mode != "instance_segmentation": 1556 segmentation = self._to_masks(segmentation, output_mode) 1557 return segmentation 1558 1559 # Set state and get state are not implemented yet, as this generator relies on having the image embeddings 1560 # in the state. However, they should not be serialized here and we have to address this a bit differently. 1561 def get_state(self): 1562 """@private 1563 """ 1564 raise NotImplementedError 1565 1566 def set_state(self, state): 1567 """@private 1568 """ 1569 raise NotImplementedError
Same as AutomaticPromptGenerator but for tiled image embeddings.
1452 def generate( 1453 self, 1454 min_size: int = 25, 1455 center_distance_threshold: float = 0.5, 1456 boundary_distance_threshold: float = 0.5, 1457 foreground_threshold: float = 0.5, 1458 multimasking: bool = False, 1459 batch_size: int = 32, 1460 nms_threshold: float = 0.9, 1461 intersection_over_min: bool = False, 1462 output_mode: str = "instance_segmentation", 1463 mask_threshold: Optional[Union[float, str]] = None, 1464 refine_with_box_prompts: bool = False, 1465 prompt_function: Optional[callable] = None, 1466 optimize_memory: bool = False, 1467 ) -> List[Dict[str, Any]]: 1468 """Generate tiling-based instance segmentation for the currently initialized image. 1469 1470 Args: 1471 min_size: Minimal object size in the segmentation result. By default, set to '25'. 1472 center_distance_threshold: The threshold for the center distance predictions. 1473 boundary_distance_threshold: The threshold for the boundary distance predictions. 1474 multimasking: Whether to use multi-mask prediction for turning the prompts into masks. 1475 batch_size: The batch size for parallelizing the prediction based on prompts. 1476 nms_threshold: The threshold for non-maximum suppression (NMS). 1477 intersection_over_min: Whether to use the minimum area of the two objects or the 1478 intersection over union area (default) in NMS. 1479 output_mode: The form masks are returned in. Possible values are: 1480 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks. 1481 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array. 1482 By default, set to 'instance_segmentation'. 1483 mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.` 1484 refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps 1485 derived from the segmentations after point prompts. Currently not supported for tiled segmentation. 1486 prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. 1487 If given, the default prompt derivation procedure is not used. Must have the following signature: 1488 ``` 1489 def prompt_function(foreground, center_distances, boundary_distances, **kwargs) 1490 ``` 1491 where `foreground`, `center_distances`, and `boundary_distances` are the respective 1492 predictions from the segmentation decoder. It must returns a dictionary containing 1493 either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`. 1494 optimize_memory: Whether to optimize the memory consumption by merging the per-slice 1495 segmentation results immediatly with NMS, rather than running a NMS for all results. 1496 This may lead to a slightly different segmentation result and is only compatible with 1497 `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`. 1498 1499 Returns: 1500 The instance segmentation masks. 1501 """ 1502 if not self.is_initialized: 1503 raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.") 1504 if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts): 1505 raise ValueError("Invalid settings") 1506 foreground, center_distances, boundary_distances =\ 1507 self._foreground, self._center_distances, self._boundary_distances 1508 1509 # 1.) Derive promtps from the decoder predictions. 
1510 prompt_function = _derive_point_prompts if prompt_function is None else prompt_function 1511 prompts = prompt_function( 1512 foreground, 1513 center_distances, 1514 boundary_distances, 1515 foreground_threshold=foreground_threshold, 1516 center_distance_threshold=center_distance_threshold, 1517 boundary_distance_threshold=boundary_distance_threshold, 1518 ) 1519 1520 # 2.) Apply the predictor to the prompts. 1521 shape = foreground.shape 1522 if prompts is None: # No prompts were derived, we can't do much further and return empty masks. 1523 return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else [] 1524 else: 1525 if optimize_memory: 1526 prompts.update(dict( 1527 min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min 1528 )) 1529 predictions = batched_tiled_inference( 1530 self._predictor, 1531 image=None, 1532 batch_size=batch_size, 1533 image_embeddings=self._image_embeddings, 1534 return_instance_segmentation=False, 1535 multimasking=multimasking, 1536 optimize_memory=optimize_memory, 1537 i=getattr(self, "_i", None), 1538 **prompts 1539 ) 1540 # Optimize memory directly returns an instance segmentation and does not 1541 # allow for any further refinements. 1542 if optimize_memory: 1543 return predictions 1544 1545 # 3.) Refine the segmentation with box prompts. 1546 if refine_with_box_prompts: 1547 # TODO 1548 raise NotImplementedError 1549 1550 # 4.) Apply non-max suppression to the masks. 1551 segmentation = util.apply_nms( 1552 predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold, 1553 intersection_over_min=intersection_over_min, 1554 ) 1555 if output_mode != "instance_segmentation": 1556 segmentation = self._to_masks(segmentation, output_mode) 1557 return segmentation
Generate tiling-based instance segmentation for the currently initialized image.
Arguments:
- min_size: Minimal object size in the segmentation result. By default, set to '25'.
- center_distance_threshold: The threshold for the center distance predictions.
- boundary_distance_threshold: The threshold for the boundary distance predictions.
- multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
- batch_size: The batch size for parallelizing the prediction based on prompts.
- nms_threshold: The threshold for non-maximum suppression (NMS).
- intersection_over_min: Whether to compute the overlap in NMS as intersection over the minimum area of the two objects, instead of intersection over union (the default).
- output_mode: The form masks are returned in. Possible values are:
- 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
- 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
- mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
- refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
- prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
If given, the default prompt derivation procedure is not used. Must have the following signature:
def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.
- optimize_memory: Whether to optimize the memory consumption by merging the per-slice segmentation results immediately with NMS, rather than running a single NMS over all results. This may lead to a slightly different segmentation result and is only compatible with refine_with_box_prompts=False and output_mode="instance_segmentation".
Returns:
The instance segmentation masks.
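A usage sketch for tiled APG on a large image; optimize_memory trades an exact global NMS for lower peak memory. The tile shape and halo values are illustrative.

segmenter = TiledAutomaticPromptGenerator(predictor, decoder)
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(256, 256))
segmentation = segmenter.generate(optimize_memory=True)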
1572def get_instance_segmentation_generator( 1573 predictor: SamPredictor, 1574 is_tiled: bool, 1575 decoder: Optional[torch.nn.Module] = None, 1576 segmentation_mode: Optional[str] = None, 1577 **kwargs, 1578) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1579 f"""Get the automatic mask generator. 1580 1581 Args: 1582 predictor: The segment anything predictor. 1583 is_tiled: Whether tiled embeddings are used. 1584 decoder: Decoder to predict instacne segmmentation. 1585 segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'. 1586 By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used if a decoder is passed, 1587 otherwise 'amg' is used. 1588 kwargs: The keyword arguments of the segmentation genetator class. 1589 1590 Returns: 1591 The segmentation generator instance. 1592 """ 1593 # Choose the segmentation decoder default depending on whether we have a decoder. 1594 if segmentation_mode is None: 1595 segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER 1596 1597 if segmentation_mode.lower() == "amg": 1598 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1599 segmenter = segmenter_class(predictor, **kwargs) 1600 elif segmentation_mode.lower() == "ais": 1601 assert decoder is not None 1602 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1603 segmenter = segmenter_class(predictor, decoder, **kwargs) 1604 elif segmentation_mode.lower() == "apg": 1605 assert decoder is not None 1606 segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator 1607 segmenter = segmenter_class(predictor, decoder, **kwargs) 1608 else: 1609 raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.") 1610 1611 return segmenter
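A sketch for choosing the generator via the segmentation mode; 'vit_b_lm' is an example model name, and a decoder is required for the 'ais' and 'apg' modes.

from micro_sam.instance_segmentation import get_predictor_and_decoder, get_instance_segmentation_generator

predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")
# 'ais' selects (Tiled)InstanceSegmentationWithDecoder, 'apg' the prompt-based generators, 'amg' needs no decoder.
segmenter = get_instance_segmentation_generator(predictor, is_tiled=False, decoder=decoder, segmentation_mode="ais")
segmenter.initialize(image)
segmentation = segmenter.generate()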