micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
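In practice the mask generators below are used in two steps: one expensive `initialize` call per image, followed by cheap `generate` calls that only re-run the post-processing. A minimal sketch of that workflow (the model type `"vit_b"` and the random test image are placeholders, and `util.get_sam_model` is assumed to work with default arguments as described in the micro_sam documentation):

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

# Placeholder input; in practice this would be a 2D microscopy image.
image = (np.random.rand(512, 512) * 255).astype("uint8")

predictor = util.get_sam_model(model_type="vit_b")  # load a Segment Anything model
amg = AutomaticMaskGenerator(predictor)

amg.initialize(image, verbose=True)          # expensive: embeddings + mask prediction
masks = amg.generate(pred_iou_thresh=0.88)   # cheap: filtering and NMS only
masks = amg.generate(pred_iou_thresh=0.75)   # re-run post-processing with other settings
```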
1"""Automated instance segmentation functionality. 2The classes implemented here extend the automatic instance segmentation from Segment Anything: 3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 4""" 5 6import os 7import warnings 8from abc import ABC 9from copy import deepcopy 10from collections import OrderedDict 11from typing import Any, Dict, List, Optional, Tuple, Union 12 13import vigra 14import numpy as np 15from skimage.measure import label, regionprops 16from skimage.segmentation import relabel_sequential 17 18import torch 19from torchvision.ops.boxes import batched_nms, box_area 20 21from torch_em.model import UNETR 22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 23 24import elf.parallel as parallel 25from elf.parallel.filters import apply_filter 26 27from nifty.tools import blocking 28 29import segment_anything.utils.amg as amg_utils 30from segment_anything.predictor import SamPredictor 31 32from . import util 33from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 34 35# 36# Utility Functionality 37# 38 39 40class _FakeInput: 41 def __init__(self, shape): 42 self.shape = shape 43 44 def __getitem__(self, index): 45 block_shape = tuple(ind.stop - ind.start for ind in index) 46 return np.zeros(block_shape, dtype="float32") 47 48 49def mask_data_to_segmentation( 50 masks: List[Dict[str, Any]], 51 with_background: bool, 52 min_object_size: int = 0, 53 max_object_size: Optional[int] = None, 54 label_masks: bool = True, 55) -> np.ndarray: 56 """Convert the output of the automatic mask generation to an instance segmentation. 57 58 Args: 59 masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. 60 Only supports output_mode=binary_mask. 61 with_background: Whether the segmentation has background. If yes this function assures that the largest 62 object in the output will be mapped to zero (the background value). 63 min_object_size: The minimal size of an object in pixels. 64 max_object_size: The maximal size of an object in pixels. 65 label_masks: Whether to apply connected components to the result before removing small objects. 66 67 Returns: 68 The instance segmentation. 69 """ 70 71 masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 72 # we could also get the shape from the crop box 73 shape = next(iter(masks))["segmentation"].shape 74 segmentation = np.zeros(shape, dtype="uint32") 75 76 def require_numpy(mask): 77 return mask.cpu().numpy() if torch.is_tensor(mask) else mask 78 79 seg_id = 1 80 for mask in masks: 81 if mask["area"] < min_object_size: 82 continue 83 if max_object_size is not None and mask["area"] > max_object_size: 84 continue 85 86 this_seg_id = mask.get("seg_id", seg_id) 87 segmentation[require_numpy(mask["segmentation"])] = this_seg_id 88 seg_id = this_seg_id + 1 89 90 if label_masks: 91 segmentation = label(segmentation).astype(segmentation.dtype) 92 93 seg_ids, sizes = np.unique(segmentation, return_counts=True) 94 95 # In some cases objects may be smaller than peviously calculated, 96 # since they are covered by other objects. We ensure these also get 97 # filtered out here. 98 filter_ids = seg_ids[sizes < min_object_size] 99 100 # If we run segmentation with background we also map the largest segment 101 # (the most likely background object) to zero. This is often zero already, 102 # but it does not hurt to reset that to zero either. 
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    segmentation[np.isin(segmentation, filter_ids)] = 0
    segmentation = relabel_sequential(segmentation)[0]

    return segmentation


#
# Classes for automatic instance segmentation
#


class AMGBase(ABC):
    """Base class for the automatic mask generators.
    """
    def __init__(self):
        # the state that has to be computed by the 'initialize' method of the child classes
        self._is_initialized = False
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None

    @property
    def is_initialized(self):
        """Whether the mask generator has already been initialized.
        """
        return self._is_initialized

    @property
    def crop_list(self):
        """The list of mask data after initialization.
        """
        return self._crop_list

    @property
    def crop_boxes(self):
        """The list of crop boxes.
        """
        return self._crop_boxes

    @property
    def original_size(self):
        """The original image size.
        """
        return self._original_size

    def _postprocess_batch(
        self,
        data,
        crop_box,
        original_size,
        pred_iou_thresh,
        stability_score_thresh,
        box_nms_thresh,
    ):
        orig_h, orig_w = original_size

        # filter by predicted IoU
        if pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > pred_iou_thresh
            data.filter(keep_mask)

        # filter by stability score
        if stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= stability_score_thresh
            data.filter(keep_mask)

        # filter boxes that touch crop boundaries
        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # return to the original image frame
        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        # the data from embedding based segmentation doesn't have the points
        # so we skip if the corresponding key can't be found
        try:
            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
        except KeyError:
            pass

        return data

    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):

        if len(mask_data["rles"]) == 0:
            return mask_data

        # filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = amg_utils.rle_to_mask(rle)

            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
            # give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores, dtype=torch.float),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data

    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
        # filter small disconnected regions and holes in masks
        if min_mask_region_area > 0:
            mask_data = self._postprocess_small_regions(
                mask_data,
                min_mask_region_area,
                max(box_nms_thresh, crop_nms_thresh),
            )

        # encode masks
        if output_mode == "coco_rle":
            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif output_mode == "binary_mask":
            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            # the data from embedding based segmentation doesn't have the points
            # so we skip if the corresponding key can't be found
            try:
                ann["point_coords"] = [mask_data["points"][idx].tolist()]
            except KeyError:
                pass
            curr_anns.append(ann)

        return curr_anns

    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
        orig_h, orig_w = original_size

        # serialize predictions and store in MaskData
        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
        if points is not None:
            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)

        del masks

        # calculate the stability scores
        data["stability_score"] = amg_utils.calculate_stability_score(
            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
        )

        # threshold masks and calculate boxes
        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
        data["masks"] = data["masks"].type(torch.bool)
        data["boxes"] = batched_mask_to_box(data["masks"])

        # compress to RLE
        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    def get_state(self) -> Dict[str, Any]:
        """Get the initialized state of the mask generator.

        Returns:
            State of the mask generator.
        """
        if not self.is_initialized:
            raise RuntimeError("The state has not been computed yet. Call initialize first.")

        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

    def set_state(self, state: Dict[str, Any]) -> None:
        """Set the state of the mask generator.

        Args:
            state: The state of the mask generator, e.g. from serialized state.
        """
        self._crop_list = state["crop_list"]
        self._crop_boxes = state["crop_boxes"]
        self._original_size = state["original_size"]
        self._is_initialized = True

    def clear_state(self):
        """Clear the state of the mask generator.
        """
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None
        self._is_initialized = False


class AutomaticMaskGenerator(AMGBase):
    """Generates an instance segmentation without prompts, using a point grid.

    This class implements the same logic as
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
    to filter these masks to enable grid search and interactively changing the post-processing.

    Use this class as follows:
    ```python
    amg = AutomaticMaskGenerator(predictor)
    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
    ```

    Args:
        predictor: The segment anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory.
        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
        crop_overlap_ratio: Sets the degree to which crops overlap.
        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
        point_grids: A list of explicit point grids used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
    """
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: Optional[int] = None,
        crop_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ):
        super().__init__()

        if points_per_side is not None:
            self.point_grids = amg_utils.build_all_layer_point_grids(
                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Exactly one of 'points_per_side' or 'point_grids' must be provided.")

        self._predictor = predictor
        self._points_per_side = points_per_side

        # we set the points per batch to 16 for mps for performance reasons
        # and otherwise keep them at the default of 64
        if points_per_batch is None:
            points_per_batch = 16 if str(predictor.device) == "mps" else 64
        self._points_per_batch = points_per_batch

        self._crop_n_layers = crop_n_layers
        self._crop_overlap_ratio = crop_overlap_ratio
        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self._stability_score_offset = stability_score_offset

    def _process_batch(self, points, im_size, crop_box, original_size):
        # run model on this batch
        transformed_points = self._predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self._predictor.predict_torch(
            point_coords=in_points[:, None, :],
            point_labels=in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )
        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
        del masks
        return data

    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
        # Crop the image and calculate embeddings.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]

        if not precomputed_embeddings:
            self._predictor.set_image(cropped_im)

        # Get the points for this crop.
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches.
        data = amg_utils.MaskData()
        n_batches = len(points_for_image) // self._points_per_batch +\
            int(len(points_for_image) % self._points_per_batch != 0)
        if pbar_init is not None:
            pbar_init(n_batches, "Predict masks for point grid prompts")

        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
            data.cat(batch_data)
            del batch_data
            if pbar_update is not None:
                pbar_update(1)

        if not precomputed_embeddings:
            self._predictor.reset_image()

        return data

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            verbose: Whether to print computation progress.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                This enables using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
        """
        original_size = image.shape[:2]
        self._original_size = original_size

        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
            original_size, self._crop_n_layers, self._crop_overlap_ratio
        )

        # We can set fixed image embeddings if we only have a single crop box (the default setting).
        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
        if len(crop_boxes) == 1:
            if image_embeddings is None:
                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
            util.set_precomputed(self._predictor, image_embeddings, i=i)
            precomputed_embeddings = True
        else:
            precomputed_embeddings = False

        # we need to cast to the image representation that is compatible with SAM
        image = util._to_image(image)

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

        crop_list = []
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(
                image, crop_box, layer_idx,
                precomputed_embeddings=precomputed_embeddings,
                pbar_init=pbar_init, pbar_update=pbar_update,
            )
            crop_list.append(crop_data)
        pbar_close()

        self._is_initialized = True
        self._crop_list = crop_list
        self._crop_boxes = crop_boxes

    @torch.no_grad()
    def generate(
        self,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        box_nms_thresh: float = 0.7,
        crop_nms_thresh: float = 0.7,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> List[Dict[str, Any]]:
        """Generate instance segmentation for the currently initialized image.

        Args:
            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
                under changes to the cutoff used to binarize the model prediction.
            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
            min_mask_region_area: Minimal size for the predicted masks.
            output_mode: The form masks are returned in.

        Returns:
            The instance segmentation masks.
        """
        if not self.is_initialized:
            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")

        data = amg_utils.MaskData()
        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
            crop_data = self._postprocess_batch(
                data=deepcopy(data_),
                crop_box=crop_box, original_size=self.original_size,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                box_nms_thresh=box_nms_thresh
            )
            data.cat(crop_data)

        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
        return masks


# Helper function for tiled embedding computation and checking consistent state.
def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size):
    if image_embeddings is None:
        if tile_shape is None or halo is None:
            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size
        )

    # Use tile shape and halo from the precomputed embeddings if not given.
    # Otherwise check that they are consistent.
    feats = image_embeddings["features"]
    tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
    if tile_shape is None:
        tile_shape = tile_shape_
    elif tile_shape != tile_shape_:
        raise ValueError(
            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeddings: {tile_shape_}."
        )
    if halo is None:
        halo = halo_
    elif halo != halo_:
        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeddings: {halo_}.")

    return image_embeddings, tile_shape, halo


class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
    """Generates an instance segmentation without prompts, using a point grid.

    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.

    Args:
        predictor: The segment anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory.
        point_grids: A list of explicit point grids used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
    """

    # We only expose the arguments that make sense for the tiled mask generator.
    # Anything related to crops doesn't make sense, because we re-use that functionality
    # for tiling, so these parameters wouldn't have any effect.
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ) -> None:
        super().__init__(
            predictor=predictor,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            point_grids=point_grids,
            stability_score_offset=stability_score_offset,
        )

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: The tile shape for embedding prediction.
            halo: The overlap between tiles.
            verbose: Whether to print computation progress.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                This enables using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding prediction.
        """
        original_size = image.shape[:2]
        self._original_size = original_size

        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
        )

        tiling = blocking([0, 0], original_size, tile_shape)
        n_tiles = tiling.numberOfBlocks

        # The crop box is always the full local tile.
        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
        pbar_init(n_tiles, "Compute masks for tile")

        # We need to cast to the image representation that is compatible with SAM.
        image = util._to_image(image)

        mask_data = []
        for tile_id in range(n_tiles):
            # set the pre-computed embeddings for this tile
            features = image_embeddings["features"][tile_id]
            tile_embeddings = {
                "features": features,
                "input_size": features.attrs["input_size"],
                "original_size": features.attrs["original_size"],
            }
            util.set_precomputed(self._predictor, tile_embeddings, i)

            # compute the mask data for this tile and append it
            this_mask_data = self._process_crop(
                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
            )
            mask_data.append(this_mask_data)
            pbar_update(1)
        pbar_close()

        # set the initialized data
        self._is_initialized = True
        self._crop_list = mask_data
        self._crop_boxes = crop_boxes


#
# Instance segmentation functionality based on fine-tuned decoder
#


class DecoderAdapter(torch.nn.Module):
    """Adapter to contain the UNETR decoder in a single module.

    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
    """
    def __init__(self, unetr):
        super().__init__()

        self.base = unetr.base
        self.out_conv = unetr.out_conv
        self.deconv_out = unetr.deconv_out
        self.decoder_head = unetr.decoder_head
        self.final_activation = unetr.final_activation
        self.postprocess_masks = unetr.postprocess_masks

        self.decoder = unetr.decoder
        self.deconv1 = unetr.deconv1
        self.deconv2 = unetr.deconv2
        self.deconv3 = unetr.deconv3
        self.deconv4 = unetr.deconv4

    def _forward_impl(self, input_):
        z12 = input_

        z9 = self.deconv1(z12)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return x

    def forward(self, input_, input_shape, original_shape):
        x = self._forward_impl(input_)
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


def get_unetr(
    image_encoder: torch.nn.Module,
    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
    device: Optional[Union[str, torch.device]] = None,
    out_channels: int = 3,
    flexible_load_checkpoint: bool = False,
) -> torch.nn.Module:
    """Get UNETR model for automatic instance segmentation.

    Args:
        image_encoder: The image encoder of the SAM model.
            This is used as encoder by the UNETR too.
        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
        device: The device.
        out_channels: The number of output channels.
        flexible_load_checkpoint: Whether to allow reinitialization of parameters
            which could not be found in the provided decoder state.

    Returns:
        The UNETR model.
776 """ 777 device = util.get_device(device) 778 779 unetr = UNETR( 780 backbone="sam", 781 encoder=image_encoder, 782 out_channels=out_channels, 783 use_sam_stats=True, 784 final_activation="Sigmoid", 785 use_skip_connection=False, 786 resize_input=True, 787 ) 788 if decoder_state is not None: 789 unetr_state_dict = unetr.state_dict() 790 for k, v in unetr_state_dict.items(): 791 if not k.startswith("encoder"): 792 if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found. 793 if k in decoder_state: # First check whether the key is available in the provided decoder state. 794 unetr_state_dict[k] = decoder_state[k] 795 else: # Otherwise, allow it to initialize it. 796 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 797 unetr_state_dict[k] = v 798 799 else: # Whether be strict on finding the parameter in the decoder state. 800 if k not in decoder_state: 801 raise RuntimeError(f"The parameters for '{k}' could not be found.") 802 unetr_state_dict[k] = decoder_state[k] 803 804 unetr.load_state_dict(unetr_state_dict) 805 806 unetr.to(device) 807 return unetr 808 809 810def get_decoder( 811 image_encoder: torch.nn.Module, 812 decoder_state: OrderedDict[str, torch.Tensor], 813 device: Optional[Union[str, torch.device]] = None, 814) -> DecoderAdapter: 815 """Get decoder to predict outputs for automatic instance segmentation 816 817 Args: 818 image_encoder: The image encoder of the SAM model. 819 decoder_state: State to initialize the weights of the UNETR decoder. 820 device: The device. 821 822 Returns: 823 The decoder for instance segmentation. 824 """ 825 unetr = get_unetr(image_encoder, decoder_state, device) 826 return DecoderAdapter(unetr) 827 828 829def get_predictor_and_decoder( 830 model_type: str, 831 checkpoint_path: Union[str, os.PathLike], 832 device: Optional[Union[str, torch.device]] = None, 833 peft_kwargs: Optional[Dict] = None, 834) -> Tuple[SamPredictor, DecoderAdapter]: 835 """Load the SAM model (predictor) and instance segmentation decoder. 836 837 This requires a checkpoint that contains the state for both predictor 838 and decoder. 839 840 Args: 841 model_type: The type of the image encoder used in the SAM model. 842 checkpoint_path: Path to the checkpoint from which to load the data. 843 device: The device. 844 peft_kwargs: Keyword arguments for the PEFT wrapper class. 845 846 Returns: 847 The SAM predictor. 848 The decoder for instance segmentation. 
849 """ 850 device = util.get_device(device) 851 predictor, state = util.get_sam_model( 852 model_type=model_type, 853 checkpoint_path=checkpoint_path, 854 device=device, 855 return_state=True, 856 peft_kwargs=peft_kwargs, 857 ) 858 if "decoder_state" not in state: 859 raise ValueError( 860 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 861 ) 862 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 863 return predictor, decoder 864 865 866def _watershed_from_center_and_boundary_distances_parallel( 867 center_distances, 868 boundary_distances, 869 foreground_map, 870 center_distance_threshold, 871 boundary_distance_threshold, 872 foreground_threshold, 873 distance_smoothing, 874 min_size, 875 tile_shape, 876 halo, 877 n_threads, 878 verbose=False, 879): 880 center_distances = apply_filter( 881 center_distances, "gaussianSmoothing", sigma=distance_smoothing, 882 block_shape=tile_shape, n_threads=n_threads 883 ) 884 boundary_distances = apply_filter( 885 boundary_distances, "gaussianSmoothing", sigma=distance_smoothing, 886 block_shape=tile_shape, n_threads=n_threads 887 ) 888 889 fg_mask = foreground_map > foreground_threshold 890 891 marker_map = np.logical_and( 892 center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold 893 ) 894 marker_map[~fg_mask] = 0 895 896 markers = np.zeros(marker_map.shape, dtype="uint64") 897 markers = parallel.label( 898 marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose, 899 ) 900 901 seg = np.zeros_like(markers, dtype="uint64") 902 seg = parallel.seeded_watershed( 903 boundary_distances, seeds=markers, out=seg, block_shape=tile_shape, 904 halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask, 905 ) 906 907 out = np.zeros_like(seg, dtype="uint64") 908 out = parallel.size_filter( 909 seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose 910 ) 911 912 return out 913 914 915class InstanceSegmentationWithDecoder: 916 """Generates an instance segmentation without prompts, using a decoder. 917 918 Implements the same interface as `AutomaticMaskGenerator`. 919 920 Use this class as follows: 921 ```python 922 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 923 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 924 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 925 ``` 926 927 Args: 928 predictor: The segment anything predictor. 929 decoder: The decoder to predict intermediate representations 930 for instance segmentation. 931 """ 932 def __init__( 933 self, 934 predictor: SamPredictor, 935 decoder: torch.nn.Module, 936 ) -> None: 937 self._predictor = predictor 938 self._decoder = decoder 939 940 # The decoder outputs. 941 self._foreground = None 942 self._center_distances = None 943 self._boundary_distances = None 944 945 self._is_initialized = False 946 947 @property 948 def is_initialized(self): 949 """Whether the mask generator has already been initialized. 950 """ 951 return self._is_initialized 952 953 @torch.no_grad() 954 def initialize( 955 self, 956 image: np.ndarray, 957 image_embeddings: Optional[util.ImageEmbeddings] = None, 958 i: Optional[int] = None, 959 verbose: bool = False, 960 pbar_init: Optional[callable] = None, 961 pbar_update: Optional[callable] = None, 962 ) -> None: 963 """Initialize image embeddings and decoder predictions for an image. 

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            verbose: Whether to be verbose.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                This enables using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
        """
        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
        pbar_init(1, "Initialize instance segmentation with decoder")

        if image_embeddings is None:
            image_embeddings = util.precompute_image_embeddings(self._predictor, image)

        # Get the image embeddings from the predictor.
        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
        embeddings = self._predictor.features
        input_shape = tuple(self._predictor.input_size)
        original_shape = tuple(self._predictor.original_size)

        # Run prediction with the UNETR decoder.
        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
        assert output.shape[0] == 3, f"{output.shape}"
        pbar_update(1)
        pbar_close()

        # Set the state.
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]
        self._is_initialized = True

    def _to_masks(self, segmentation, output_mode):
        if output_mode != "binary_mask":
            raise NotImplementedError

        props = regionprops(segmentation)
        ndim = segmentation.ndim
        assert ndim in (2, 3)

        shape = segmentation.shape
        if ndim == 2:
            crop_box = [0, shape[1], 0, shape[0]]
        else:
            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]

        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
        def to_bbox_2d(bbox):
            y0, x0 = bbox[0], bbox[1]
            w = bbox[3] - x0
            h = bbox[2] - y0
            return [x0, w, y0, h]

        def to_bbox_3d(bbox):
            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
            w = bbox[5] - x0
            h = bbox[4] - y0
            d = bbox[3] - z0  # depth from the z-extent of the skimage bbox (z1 - z0)
            return [x0, w, y0, h, z0, d]

        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
        masks = [
            {
                "segmentation": segmentation == prop.label,
                "area": prop.area,
                "bbox": to_bbox(prop.bbox),
                "crop_box": crop_box,
                "seg_id": prop.label,
            } for prop in props
        ]
        return masks

    def generate(
        self,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        foreground_smoothing: float = 1.0,
        distance_smoothing: float = 1.6,
        min_size: int = 0,
        output_mode: Optional[str] = "binary_mask",
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        n_threads: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Generate instance segmentation for the currently initialized image.

        Args:
            center_distance_threshold: Center distance predictions below this value will be
                used to find seeds (intersected with thresholded boundary distance predictions).
            boundary_distance_threshold: Boundary distance predictions below this value will be
                used to find seeds (intersected with thresholded center distance predictions).
            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
                checkerboard artifacts in the prediction.
            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
            distance_smoothing: Sigma value for smoothing the distance predictions.
            min_size: Minimal object size in the segmentation result.
            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
                This parameter is independent from the tile shape for computing the embeddings.
                If not given then post-processing will not be parallelized.
            halo: Halo for parallel post-processing. See also `tile_shape`.
            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.

        Returns:
            The instance segmentation masks.
        """
        if not self.is_initialized:
            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")

        if foreground_smoothing > 0:
            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
        else:
            foreground = self._foreground

        if tile_shape is None:
            segmentation = watershed_from_center_and_boundary_distances(
                center_distances=self._center_distances,
                boundary_distances=self._boundary_distances,
                foreground_map=foreground,
                center_distance_threshold=center_distance_threshold,
                boundary_distance_threshold=boundary_distance_threshold,
                foreground_threshold=foreground_threshold,
                distance_smoothing=distance_smoothing,
                min_size=min_size,
            )
        else:
            if halo is None:
                raise ValueError("You must pass a value for halo if tile_shape is given.")
            segmentation = _watershed_from_center_and_boundary_distances_parallel(
                center_distances=self._center_distances,
                boundary_distances=self._boundary_distances,
                foreground_map=foreground,
                center_distance_threshold=center_distance_threshold,
                boundary_distance_threshold=boundary_distance_threshold,
                foreground_threshold=foreground_threshold,
                distance_smoothing=distance_smoothing,
                min_size=min_size,
                tile_shape=tile_shape,
                halo=halo,
                n_threads=n_threads,
                verbose=False,
            )

        if output_mode is not None:
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation

    def get_state(self) -> Dict[str, Any]:
        """Get the initialized state of the instance segmenter.

        Returns:
            Instance segmentation state.
        """
        if not self.is_initialized:
            raise RuntimeError("The state has not been computed yet. Call initialize first.")

        return {
            "foreground": self._foreground,
            "center_distances": self._center_distances,
            "boundary_distances": self._boundary_distances,
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        """Set the state of the instance segmenter.

        Args:
            state: The instance segmentation state.
        """
        self._foreground = state["foreground"]
        self._center_distances = state["center_distances"]
        self._boundary_distances = state["boundary_distances"]
        self._is_initialized = True

    def clear_state(self):
        """Clear the state of the instance segmenter.
        """
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None
        self._is_initialized = False


class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
    """

    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
        batched_embeddings = torch.cat(batched_embeddings)
        output = self._decoder._forward_impl(batched_embeddings)

        batched_output = []
        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
            batched_output.append(x.cpu().numpy())
        return batched_output

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
    ) -> None:
        """Initialize image embeddings and decoder predictions for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: Shape of the tiles for precomputing image embeddings.
            halo: Overlap of the tiles for tiled precomputation of image embeddings.
            verbose: Dummy input to be compatible with other function signatures.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                This enables using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1197 """ 1198 original_size = image.shape[:2] 1199 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1200 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 1201 ) 1202 tiling = blocking([0, 0], original_size, tile_shape) 1203 1204 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1205 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1206 1207 foreground = np.zeros(original_size, dtype="float32") 1208 center_distances = np.zeros(original_size, dtype="float32") 1209 boundary_distances = np.zeros(original_size, dtype="float32") 1210 1211 n_tiles = tiling.numberOfBlocks 1212 n_batches = int(np.ceil(n_tiles / batch_size)) 1213 1214 for batch_id in range(n_batches): 1215 tile_start = batch_id * batch_size 1216 tile_stop = min(tile_start + batch_size, n_tiles) 1217 1218 batched_embeddings, input_shapes, original_shapes = [], [], [] 1219 for tile_id in range(tile_start, tile_stop): 1220 # Get the image embeddings from the predictor for this tile. 1221 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1222 1223 batched_embeddings.append(self._predictor.features) 1224 input_shapes.append(tuple(self._predictor.input_size)) 1225 original_shapes.append(tuple(self._predictor.original_size)) 1226 1227 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1228 1229 for output_id, tile_id in enumerate(range(tile_start, tile_stop)): 1230 output = batched_output[output_id] 1231 assert output.shape[0] == 3 1232 1233 # Set the predictions in the output for this tile. 1234 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1235 local_bb = tuple( 1236 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1237 ) 1238 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1239 1240 foreground[inner_bb] = output[0][local_bb] 1241 center_distances[inner_bb] = output[1][local_bb] 1242 boundary_distances[inner_bb] = output[2][local_bb] 1243 pbar_update(1) 1244 1245 pbar_close() 1246 1247 # Set the state. 1248 self._foreground = foreground 1249 self._center_distances = center_distances 1250 self._boundary_distances = boundary_distances 1251 self._is_initialized = True 1252 1253 1254def get_amg( 1255 predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs, 1256) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1257 """Get the automatic mask generator class. 1258 1259 Args: 1260 predictor: The segment anything predictor. 1261 is_tiled: Whether tiled embeddings are used. 1262 decoder: Decoder to predict instacne segmmentation. 1263 kwargs: The keyword arguments for the amg class. 1264 1265 Returns: 1266 The automatic mask generator. 1267 """ 1268 if decoder is None: 1269 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1270 segmenter = segmenter_class(predictor, **kwargs) 1271 else: 1272 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1273 segmenter = segmenter_class(predictor, decoder, **kwargs) 1274 1275 return segmenter
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    with_background: bool,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
) -> np.ndarray
Convert the output of the automatic mask generation to an instance segmentation.
Arguments:
- masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
- with_background: Whether the segmentation has background. If yes this function assures that the largest object in the output will be mapped to zero (the background value).
- min_object_size: The minimal size of an object in pixels.
- max_object_size: The maximal size of an object in pixels.
- label_masks: Whether to apply connected components to the result before removing small objects.
Returns:
The instance segmentation.
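For example, the mask records returned by `AutomaticMaskGenerator.generate` (with the default `output_mode="binary_mask"`) can be collapsed into a single label image; the size threshold below is an arbitrary example value:

```python
from micro_sam.instance_segmentation import mask_data_to_segmentation

# `masks` is the list of mask records from AutomaticMaskGenerator.generate(...).
segmentation = mask_data_to_segmentation(
    masks,
    with_background=True,   # map the largest object (assumed to be background) to 0
    min_object_size=25,     # discard objects smaller than 25 pixels
)
```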
class AMGBase(ABC)
Base class for the automatic mask generators.
129 @property 130 def is_initialized(self): 131 """Whether the mask generator has already been initialized. 132 """ 133 return self._is_initialized
Whether the mask generator has already been initialized.
135 @property 136 def crop_list(self): 137 """The list of mask data after initialization. 138 """ 139 return self._crop_list
The list of mask data after initialization.
141 @property 142 def crop_boxes(self): 143 """The list of crop boxes. 144 """ 145 return self._crop_boxes
The list of crop boxes.
147 @property 148 def original_size(self): 149 """The original image size. 150 """ 151 return self._original_size
The original image size.
308 def get_state(self) -> Dict[str, Any]: 309 """Get the initialized state of the mask generator. 310 311 Returns: 312 State of the mask generator. 313 """ 314 if not self.is_initialized: 315 raise RuntimeError("The state has not been computed yet. Call initialize first.") 316 317 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
319 def set_state(self, state: Dict[str, Any]) -> None: 320 """Set the state of the mask generator. 321 322 Args: 323 state: The state of the mask generator, e.g. from serialized state. 324 """ 325 self._crop_list = state["crop_list"] 326 self._crop_boxes = state["crop_boxes"] 327 self._original_size = state["original_size"] 328 self._is_initialized = True
Set the state of the mask generator.
Arguments:
- state: The state of the mask generator, e.g. from serialized state.
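For illustration, a minimal sketch of how these state accessors can be used to cache the expensive initialization of a mask generator (here the `AutomaticMaskGenerator` defined below) and restore it later. The use of pickle for serialization, the model type, the file name and the placeholder image are assumptions, not part of the documented API.

```python
import pickle
import numpy as np

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator

image = np.zeros((512, 512), dtype="uint8")  # placeholder for a real 2D image
predictor = get_sam_model(model_type="vit_b")  # model choice is illustrative

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # the expensive step

# Cache the computed state so mask generation can be re-run later
# without repeating the initialization.
with open("amg_state.pkl", "wb") as f:
    pickle.dump(amg.get_state(), f)

# Later: restore the state and generate masks directly.
amg = AutomaticMaskGenerator(predictor)
with open("amg_state.pkl", "rb") as f:
    amg.set_state(pickle.load(f))
masks = amg.generate(pred_iou_thresh=0.88)
```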
339class AutomaticMaskGenerator(AMGBase): 340 """Generates an instance segmentation without prompts, using a point grid. 341 342 This class implements the same logic as 343 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 344 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 345 to filter these masks to enable grid search and interactively changing the post-processing. 346 347 Use this class as follows: 348 ```python 349 amg = AutomaticMaskGenerator(predictor) 350 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 351 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 352 ``` 353 354 Args: 355 predictor: The segment anything predictor. 356 points_per_side: The number of points to be sampled along one side of the image. 357 If None, `point_grids` must provide explicit point sampling. 358 points_per_batch: The number of points run simultaneously by the model. 359 Higher numbers may be faster but use more GPU memory. 360 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 361 crop_overlap_ratio: Sets the degree to which crops overlap. 362 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 363 point_grids: A lisst over explicit grids of points used for sampling masks. 364 Normalized to [0, 1] with respect to the image coordinate system. 365 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 366 """ 367 def __init__( 368 self, 369 predictor: SamPredictor, 370 points_per_side: Optional[int] = 32, 371 points_per_batch: Optional[int] = None, 372 crop_n_layers: int = 0, 373 crop_overlap_ratio: float = 512 / 1500, 374 crop_n_points_downscale_factor: int = 1, 375 point_grids: Optional[List[np.ndarray]] = None, 376 stability_score_offset: float = 1.0, 377 ): 378 super().__init__() 379 380 if points_per_side is not None: 381 self.point_grids = amg_utils.build_all_layer_point_grids( 382 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 383 ) 384 elif point_grids is not None: 385 self.point_grids = point_grids 386 else: 387 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 388 389 self._predictor = predictor 390 self._points_per_side = points_per_side 391 392 # we set the points per batch to 16 for mps for performance reasons 393 # and otherwise keep them at the default of 64 394 if points_per_batch is None: 395 points_per_batch = 16 if str(predictor.device) == "mps" else 64 396 self._points_per_batch = points_per_batch 397 398 self._crop_n_layers = crop_n_layers 399 self._crop_overlap_ratio = crop_overlap_ratio 400 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 401 self._stability_score_offset = stability_score_offset 402 403 def _process_batch(self, points, im_size, crop_box, original_size): 404 # run model on this batch 405 transformed_points = self._predictor.transform.apply_coords(points, im_size) 406 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 407 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 408 masks, iou_preds, _ = self._predictor.predict_torch( 409 point_coords=in_points[:, None, :], 410 point_labels=in_labels[:, None], 411 multimask_output=True, 412 return_logits=True, 413 ) 414 data = 
self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 415 del masks 416 return data 417 418 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 419 # Crop the image and calculate embeddings. 420 x0, y0, x1, y1 = crop_box 421 cropped_im = image[y0:y1, x0:x1, :] 422 cropped_im_size = cropped_im.shape[:2] 423 424 if not precomputed_embeddings: 425 self._predictor.set_image(cropped_im) 426 427 # Get the points for this crop. 428 points_scale = np.array(cropped_im_size)[None, ::-1] 429 points_for_image = self.point_grids[crop_layer_idx] * points_scale 430 431 # Generate masks for this crop in batches. 432 data = amg_utils.MaskData() 433 n_batches = len(points_for_image) // self._points_per_batch +\ 434 int(len(points_for_image) % self._points_per_batch != 0) 435 if pbar_init is not None: 436 pbar_init(n_batches, "Predict masks for point grid prompts") 437 438 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 439 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 440 data.cat(batch_data) 441 del batch_data 442 if pbar_update is not None: 443 pbar_update(1) 444 445 if not precomputed_embeddings: 446 self._predictor.reset_image() 447 448 return data 449 450 @torch.no_grad() 451 def initialize( 452 self, 453 image: np.ndarray, 454 image_embeddings: Optional[util.ImageEmbeddings] = None, 455 i: Optional[int] = None, 456 verbose: bool = False, 457 pbar_init: Optional[callable] = None, 458 pbar_update: Optional[callable] = None, 459 ) -> None: 460 """Initialize image embeddings and masks for an image. 461 462 Args: 463 image: The input image, volume or timeseries. 464 image_embeddings: Optional precomputed image embeddings. 465 See `util.precompute_image_embeddings` for details. 466 i: Index for the image data. Required if `image` has three spatial dimensions 467 or a time dimension and two spatial dimensions. 468 verbose: Whether to print computation progress. 469 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 470 Can be used together with pbar_update to handle napari progress bar in other thread. 471 To enables using this function within a threadworker. 472 pbar_update: Callback to update an external progress bar. 473 """ 474 original_size = image.shape[:2] 475 self._original_size = original_size 476 477 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 478 original_size, self._crop_n_layers, self._crop_overlap_ratio 479 ) 480 481 # We can set fixed image embeddings if we only have a single crop box (the default setting). 482 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 
483 if len(crop_boxes) == 1: 484 if image_embeddings is None: 485 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 486 util.set_precomputed(self._predictor, image_embeddings, i=i) 487 precomputed_embeddings = True 488 else: 489 precomputed_embeddings = False 490 491 # we need to cast to the image representation that is compatible with SAM 492 image = util._to_image(image) 493 494 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 495 496 crop_list = [] 497 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 498 crop_data = self._process_crop( 499 image, crop_box, layer_idx, 500 precomputed_embeddings=precomputed_embeddings, 501 pbar_init=pbar_init, pbar_update=pbar_update, 502 ) 503 crop_list.append(crop_data) 504 pbar_close() 505 506 self._is_initialized = True 507 self._crop_list = crop_list 508 self._crop_boxes = crop_boxes 509 510 @torch.no_grad() 511 def generate( 512 self, 513 pred_iou_thresh: float = 0.88, 514 stability_score_thresh: float = 0.95, 515 box_nms_thresh: float = 0.7, 516 crop_nms_thresh: float = 0.7, 517 min_mask_region_area: int = 0, 518 output_mode: str = "binary_mask", 519 ) -> List[Dict[str, Any]]: 520 """Generate instance segmentation for the currently initialized image. 521 522 Args: 523 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 524 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 525 under changes to the cutoff used to binarize the model prediction. 526 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 527 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 528 min_mask_region_area: Minimal size for the predicted masks. 529 output_mode: The form masks are returned in. 530 531 Returns: 532 The instance segmentation masks. 533 """ 534 if not self.is_initialized: 535 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 536 537 data = amg_utils.MaskData() 538 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 539 crop_data = self._postprocess_batch( 540 data=deepcopy(data_), 541 crop_box=crop_box, original_size=self.original_size, 542 pred_iou_thresh=pred_iou_thresh, 543 stability_score_thresh=stability_score_thresh, 544 box_nms_thresh=box_nms_thresh 545 ) 546 data.cat(crop_data) 547 548 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 549 # Prefer masks from smaller crops 550 scores = 1 / box_area(data["crop_boxes"]) 551 scores = scores.to(data["boxes"].device) 552 keep_by_nms = batched_nms( 553 data["boxes"].float(), 554 scores, 555 torch.zeros_like(data["boxes"][:, 0]), # categories 556 iou_threshold=crop_nms_thresh, 557 ) 558 data.filter(keep_by_nms) 559 560 data.to_numpy() 561 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 562 return masks
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive step of generating masks from the cheap post-processing that filters them, which enables grid search over the post-processing parameters and changing them interactively.
Use this class as follows:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image) # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing different post-processing parameters.
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
- crop_overlap_ratio: Sets the degree to which crops overlap.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
- point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
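As a hedged sketch of the `point_grids` alternative to `points_per_side`: `build_point_grid` from `segment_anything.utils.amg` creates a normalized grid of points; the choice of 16 points per side is illustrative, and `predictor` is assumed to be set up as in the earlier sketch.

```python
import segment_anything.utils.amg as amg_utils
from micro_sam.instance_segmentation import AutomaticMaskGenerator

# One explicit grid for crop layer 0: 16 x 16 points, normalized to [0, 1].
point_grids = [amg_utils.build_point_grid(16)]
amg = AutomaticMaskGenerator(predictor, points_per_side=None, point_grids=point_grids)
```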
367 def __init__( 368 self, 369 predictor: SamPredictor, 370 points_per_side: Optional[int] = 32, 371 points_per_batch: Optional[int] = None, 372 crop_n_layers: int = 0, 373 crop_overlap_ratio: float = 512 / 1500, 374 crop_n_points_downscale_factor: int = 1, 375 point_grids: Optional[List[np.ndarray]] = None, 376 stability_score_offset: float = 1.0, 377 ): 378 super().__init__() 379 380 if points_per_side is not None: 381 self.point_grids = amg_utils.build_all_layer_point_grids( 382 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 383 ) 384 elif point_grids is not None: 385 self.point_grids = point_grids 386 else: 387 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 388 389 self._predictor = predictor 390 self._points_per_side = points_per_side 391 392 # we set the points per batch to 16 for mps for performance reasons 393 # and otherwise keep them at the default of 64 394 if points_per_batch is None: 395 points_per_batch = 16 if str(predictor.device) == "mps" else 64 396 self._points_per_batch = points_per_batch 397 398 self._crop_n_layers = crop_n_layers 399 self._crop_overlap_ratio = crop_overlap_ratio 400 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 401 self._stability_score_offset = stability_score_offset
450 @torch.no_grad() 451 def initialize( 452 self, 453 image: np.ndarray, 454 image_embeddings: Optional[util.ImageEmbeddings] = None, 455 i: Optional[int] = None, 456 verbose: bool = False, 457 pbar_init: Optional[callable] = None, 458 pbar_update: Optional[callable] = None, 459 ) -> None: 460 """Initialize image embeddings and masks for an image. 461 462 Args: 463 image: The input image, volume or timeseries. 464 image_embeddings: Optional precomputed image embeddings. 465 See `util.precompute_image_embeddings` for details. 466 i: Index for the image data. Required if `image` has three spatial dimensions 467 or a time dimension and two spatial dimensions. 468 verbose: Whether to print computation progress. 469 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 470 Can be used together with pbar_update to handle napari progress bar in other thread. 471 To enables using this function within a threadworker. 472 pbar_update: Callback to update an external progress bar. 473 """ 474 original_size = image.shape[:2] 475 self._original_size = original_size 476 477 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 478 original_size, self._crop_n_layers, self._crop_overlap_ratio 479 ) 480 481 # We can set fixed image embeddings if we only have a single crop box (the default setting). 482 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 483 if len(crop_boxes) == 1: 484 if image_embeddings is None: 485 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 486 util.set_precomputed(self._predictor, image_embeddings, i=i) 487 precomputed_embeddings = True 488 else: 489 precomputed_embeddings = False 490 491 # we need to cast to the image representation that is compatible with SAM 492 image = util._to_image(image) 493 494 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 495 496 crop_list = [] 497 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 498 crop_data = self._process_crop( 499 image, crop_box, layer_idx, 500 precomputed_embeddings=precomputed_embeddings, 501 pbar_init=pbar_init, pbar_update=pbar_update, 502 ) 503 crop_list.append(crop_data) 504 pbar_close() 505 506 self._is_initialized = True 507 self._crop_list = crop_list 508 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
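A small sketch of passing precomputed embeddings to `initialize`, assuming `predictor`, `amg` and `image` as in the earlier sketches; the zarr save path is illustrative and only needed if the embeddings should be cached on disk.

```python
from micro_sam import util

# Compute the image embeddings once (and cache them to the given zarr path) ...
image_embeddings = util.precompute_image_embeddings(predictor, image, save_path="embeddings.zarr")
# ... then reuse them, so that initialize does not have to recompute them.
amg.initialize(image, image_embeddings=image_embeddings)
```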
510 @torch.no_grad() 511 def generate( 512 self, 513 pred_iou_thresh: float = 0.88, 514 stability_score_thresh: float = 0.95, 515 box_nms_thresh: float = 0.7, 516 crop_nms_thresh: float = 0.7, 517 min_mask_region_area: int = 0, 518 output_mode: str = "binary_mask", 519 ) -> List[Dict[str, Any]]: 520 """Generate instance segmentation for the currently initialized image. 521 522 Args: 523 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 524 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 525 under changes to the cutoff used to binarize the model prediction. 526 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 527 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 528 min_mask_region_area: Minimal size for the predicted masks. 529 output_mode: The form masks are returned in. 530 531 Returns: 532 The instance segmentation masks. 533 """ 534 if not self.is_initialized: 535 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 536 537 data = amg_utils.MaskData() 538 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 539 crop_data = self._postprocess_batch( 540 data=deepcopy(data_), 541 crop_box=crop_box, original_size=self.original_size, 542 pred_iou_thresh=pred_iou_thresh, 543 stability_score_thresh=stability_score_thresh, 544 box_nms_thresh=box_nms_thresh 545 ) 546 data.cat(crop_data) 547 548 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 549 # Prefer masks from smaller crops 550 scores = 1 / box_area(data["crop_boxes"]) 551 scores = scores.to(data["boxes"].device) 552 keep_by_nms = batched_nms( 553 data["boxes"].float(), 554 scores, 555 torch.zeros_like(data["boxes"][:, 0]), # categories 556 iou_threshold=crop_nms_thresh, 557 ) 558 data.filter(keep_by_nms) 559 560 data.to_numpy() 561 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 562 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
- min_mask_region_area: Minimal size for the predicted masks.
- output_mode: The form masks are returned in.
Returns:
The instance segmentation masks.
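A hedged sketch of turning the generated mask records into a label image with `mask_data_to_segmentation` from this module, assuming `amg` was initialized as above; the threshold and size values are illustrative.

```python
from micro_sam.instance_segmentation import mask_data_to_segmentation

masks = amg.generate(pred_iou_thresh=0.88, stability_score_thresh=0.95)
# Convert the list of mask records into an instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```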
592class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 593 """Generates an instance segmentation without prompts, using a point grid. 594 595 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 596 597 Args: 598 predictor: The segment anything predictor. 599 points_per_side: The number of points to be sampled along one side of the image. 600 If None, `point_grids` must provide explicit point sampling. 601 points_per_batch: The number of points run simultaneously by the model. 602 Higher numbers may be faster but use more GPU memory. 603 point_grids: A lisst over explicit grids of points used for sampling masks. 604 Normalized to [0, 1] with respect to the image coordinate system. 605 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 606 """ 607 608 # We only expose the arguments that make sense for the tiled mask generator. 609 # Anything related to crops doesn't make sense, because we re-use that functionality 610 # for tiling, so these parameters wouldn't have any effect. 611 def __init__( 612 self, 613 predictor: SamPredictor, 614 points_per_side: Optional[int] = 32, 615 points_per_batch: int = 64, 616 point_grids: Optional[List[np.ndarray]] = None, 617 stability_score_offset: float = 1.0, 618 ) -> None: 619 super().__init__( 620 predictor=predictor, 621 points_per_side=points_per_side, 622 points_per_batch=points_per_batch, 623 point_grids=point_grids, 624 stability_score_offset=stability_score_offset, 625 ) 626 627 @torch.no_grad() 628 def initialize( 629 self, 630 image: np.ndarray, 631 image_embeddings: Optional[util.ImageEmbeddings] = None, 632 i: Optional[int] = None, 633 tile_shape: Optional[Tuple[int, int]] = None, 634 halo: Optional[Tuple[int, int]] = None, 635 verbose: bool = False, 636 pbar_init: Optional[callable] = None, 637 pbar_update: Optional[callable] = None, 638 batch_size: int = 1, 639 ) -> None: 640 """Initialize image embeddings and masks for an image. 641 642 Args: 643 image: The input image, volume or timeseries. 644 image_embeddings: Optional precomputed image embeddings. 645 See `util.precompute_image_embeddings` for details. 646 i: Index for the image data. Required if `image` has three spatial dimensions 647 or a time dimension and two spatial dimensions. 648 tile_shape: The tile shape for embedding prediction. 649 halo: The overlap of between tiles. 650 verbose: Whether to print computation progress. 651 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 652 Can be used together with pbar_update to handle napari progress bar in other thread. 653 To enables using this function within a threadworker. 654 pbar_update: Callback to update an external progress bar. 655 batch_size: The batch size for image embedding prediction. 656 """ 657 original_size = image.shape[:2] 658 self._original_size = original_size 659 660 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 661 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 662 ) 663 664 tiling = blocking([0, 0], original_size, tile_shape) 665 n_tiles = tiling.numberOfBlocks 666 667 # The crop box is always the full local tile. 
668 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 669 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 670 671 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 672 pbar_init(n_tiles, "Compute masks for tile") 673 674 # We need to cast to the image representation that is compatible with SAM. 675 image = util._to_image(image) 676 677 mask_data = [] 678 for tile_id in range(n_tiles): 679 # set the pre-computed embeddings for this tile 680 features = image_embeddings["features"][tile_id] 681 tile_embeddings = { 682 "features": features, 683 "input_size": features.attrs["input_size"], 684 "original_size": features.attrs["original_size"], 685 } 686 util.set_precomputed(self._predictor, tile_embeddings, i) 687 688 # compute the mask data for this tile and append it 689 this_mask_data = self._process_crop( 690 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 691 ) 692 mask_data.append(this_mask_data) 693 pbar_update(1) 694 pbar_close() 695 696 # set the initialized data 697 self._is_initialized = True 698 self._crop_list = mask_data 699 self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
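A minimal usage sketch for the tiled mask generator, assuming a large 2D `image` and a `predictor` as above; the tile shape and halo values are illustrative.

```python
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator

amg = TiledAutomaticMaskGenerator(predictor)
# Embeddings are computed per tile; the halo is the overlap between adjacent tiles.
amg.initialize(image, tile_shape=(1024, 1024), halo=(256, 256))
masks = amg.generate(pred_iou_thresh=0.88)
```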
611 def __init__( 612 self, 613 predictor: SamPredictor, 614 points_per_side: Optional[int] = 32, 615 points_per_batch: int = 64, 616 point_grids: Optional[List[np.ndarray]] = None, 617 stability_score_offset: float = 1.0, 618 ) -> None: 619 super().__init__( 620 predictor=predictor, 621 points_per_side=points_per_side, 622 points_per_batch=points_per_batch, 623 point_grids=point_grids, 624 stability_score_offset=stability_score_offset, 625 )
627 @torch.no_grad() 628 def initialize( 629 self, 630 image: np.ndarray, 631 image_embeddings: Optional[util.ImageEmbeddings] = None, 632 i: Optional[int] = None, 633 tile_shape: Optional[Tuple[int, int]] = None, 634 halo: Optional[Tuple[int, int]] = None, 635 verbose: bool = False, 636 pbar_init: Optional[callable] = None, 637 pbar_update: Optional[callable] = None, 638 batch_size: int = 1, 639 ) -> None: 640 """Initialize image embeddings and masks for an image. 641 642 Args: 643 image: The input image, volume or timeseries. 644 image_embeddings: Optional precomputed image embeddings. 645 See `util.precompute_image_embeddings` for details. 646 i: Index for the image data. Required if `image` has three spatial dimensions 647 or a time dimension and two spatial dimensions. 648 tile_shape: The tile shape for embedding prediction. 649 halo: The overlap of between tiles. 650 verbose: Whether to print computation progress. 651 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 652 Can be used together with pbar_update to handle napari progress bar in other thread. 653 To enables using this function within a threadworker. 654 pbar_update: Callback to update an external progress bar. 655 batch_size: The batch size for image embedding prediction. 656 """ 657 original_size = image.shape[:2] 658 self._original_size = original_size 659 660 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 661 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 662 ) 663 664 tiling = blocking([0, 0], original_size, tile_shape) 665 n_tiles = tiling.numberOfBlocks 666 667 # The crop box is always the full local tile. 668 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 669 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 670 671 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 672 pbar_init(n_tiles, "Compute masks for tile") 673 674 # We need to cast to the image representation that is compatible with SAM. 675 image = util._to_image(image) 676 677 mask_data = [] 678 for tile_id in range(n_tiles): 679 # set the pre-computed embeddings for this tile 680 features = image_embeddings["features"][tile_id] 681 tile_embeddings = { 682 "features": features, 683 "input_size": features.attrs["input_size"], 684 "original_size": features.attrs["original_size"], 685 } 686 util.set_precomputed(self._predictor, tile_embeddings, i) 687 688 # compute the mask data for this tile and append it 689 this_mask_data = self._process_crop( 690 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 691 ) 692 mask_data.append(this_mask_data) 693 pbar_update(1) 694 pbar_close() 695 696 # set the initialized data 697 self._is_initialized = True 698 self._crop_list = mask_data 699 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding prediction.
707class DecoderAdapter(torch.nn.Module): 708 """Adapter to contain the UNETR decoder in a single module. 709 710 To apply the decoder on top of pre-computed embeddings for the segmentation functionality. 711 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 712 """ 713 def __init__(self, unetr): 714 super().__init__() 715 716 self.base = unetr.base 717 self.out_conv = unetr.out_conv 718 self.deconv_out = unetr.deconv_out 719 self.decoder_head = unetr.decoder_head 720 self.final_activation = unetr.final_activation 721 self.postprocess_masks = unetr.postprocess_masks 722 723 self.decoder = unetr.decoder 724 self.deconv1 = unetr.deconv1 725 self.deconv2 = unetr.deconv2 726 self.deconv3 = unetr.deconv3 727 self.deconv4 = unetr.deconv4 728 729 def _forward_impl(self, input_): 730 z12 = input_ 731 732 z9 = self.deconv1(z12) 733 z6 = self.deconv2(z9) 734 z3 = self.deconv3(z6) 735 z0 = self.deconv4(z3) 736 737 updated_from_encoder = [z9, z6, z3] 738 739 x = self.base(z12) 740 x = self.decoder(x, encoder_inputs=updated_from_encoder) 741 x = self.deconv_out(x) 742 743 x = torch.cat([x, z0], dim=1) 744 x = self.decoder_head(x) 745 746 x = self.out_conv(x) 747 if self.final_activation is not None: 748 x = self.final_activation(x) 749 return x 750 751 def forward(self, input_, input_shape, original_shape): 752 x = self._forward_impl(input_) 753 x = self.postprocess_masks(x, input_shape, original_shape) 754 return x
Adapter to contain the UNETR decoder in a single module.
To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
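For orientation, a sketch of applying the adapter to precomputed SAM embeddings, mirroring how it is used further below in this module; `predictor` is assumed to have an image set and `decoder` to be a DecoderAdapter instance (see get_decoder below).

```python
# Assumes `predictor` has an image set and `decoder` is a DecoderAdapter.
embeddings = predictor.features
input_shape = tuple(predictor.input_size)
original_shape = tuple(predictor.original_size)

# Three output channels: foreground, center distances, boundary distances.
output = decoder(embeddings, input_shape, original_shape)
```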
713 def __init__(self, unetr): 714 super().__init__() 715 716 self.base = unetr.base 717 self.out_conv = unetr.out_conv 718 self.deconv_out = unetr.deconv_out 719 self.decoder_head = unetr.decoder_head 720 self.final_activation = unetr.final_activation 721 self.postprocess_masks = unetr.postprocess_masks 722 723 self.decoder = unetr.decoder 724 self.deconv1 = unetr.deconv1 725 self.deconv2 = unetr.deconv2 726 self.deconv3 = unetr.deconv3 727 self.deconv4 = unetr.deconv4
Initialize internal Module state, shared by both nn.Module and ScriptModule.
751 def forward(self, input_, input_shape, original_shape): 752 x = self._forward_impl(input_) 753 x = self.postprocess_masks(x, input_shape, original_shape) 754 return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
757def get_unetr( 758 image_encoder: torch.nn.Module, 759 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 760 device: Optional[Union[str, torch.device]] = None, 761 out_channels: int = 3, 762 flexible_load_checkpoint: bool = False, 763) -> torch.nn.Module: 764 """Get UNETR model for automatic instance segmentation. 765 766 Args: 767 image_encoder: The image encoder of the SAM model. 768 This is used as encoder by the UNETR too. 769 decoder_state: Optional decoder state to initialize the weights of the UNETR decoder. 770 device: The device. 771 out_channels: The number of output channels. 772 flexible_load_checkpoint: Whether to allow reinitialization of parameters 773 which could not be found in the provided decoder state. 774 775 Returns: 776 The UNETR model. 777 """ 778 device = util.get_device(device) 779 780 unetr = UNETR( 781 backbone="sam", 782 encoder=image_encoder, 783 out_channels=out_channels, 784 use_sam_stats=True, 785 final_activation="Sigmoid", 786 use_skip_connection=False, 787 resize_input=True, 788 ) 789 if decoder_state is not None: 790 unetr_state_dict = unetr.state_dict() 791 for k, v in unetr_state_dict.items(): 792 if not k.startswith("encoder"): 793 if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found. 794 if k in decoder_state: # First check whether the key is available in the provided decoder state. 795 unetr_state_dict[k] = decoder_state[k] 796 else: # Otherwise, allow it to initialize it. 797 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 798 unetr_state_dict[k] = v 799 800 else: # Whether be strict on finding the parameter in the decoder state. 801 if k not in decoder_state: 802 raise RuntimeError(f"The parameters for '{k}' could not be found.") 803 unetr_state_dict[k] = decoder_state[k] 804 805 unetr.load_state_dict(unetr_state_dict) 806 807 unetr.to(device) 808 return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device.
- out_channels: The number of output channels.
- flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state.
Returns:
The UNETR model.
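A hedged sketch of constructing the UNETR from a SAM predictor's image encoder; passing `decoder_state=None` to get a freshly initialized decoder is an assumption about typical use, and `predictor` is assumed as in the earlier sketches.

```python
from micro_sam.instance_segmentation import get_unetr

# Wrap the SAM image encoder into a UNETR; without a decoder_state the decoder
# weights are freshly initialized, pass a state to load pretrained weights instead.
unetr = get_unetr(predictor.model.image_encoder, decoder_state=None, out_channels=3)
```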
811def get_decoder( 812 image_encoder: torch.nn.Module, 813 decoder_state: OrderedDict[str, torch.Tensor], 814 device: Optional[Union[str, torch.device]] = None, 815) -> DecoderAdapter: 816 """Get decoder to predict outputs for automatic instance segmentation 817 818 Args: 819 image_encoder: The image encoder of the SAM model. 820 decoder_state: State to initialize the weights of the UNETR decoder. 821 device: The device. 822 823 Returns: 824 The decoder for instance segmentation. 825 """ 826 unetr = get_unetr(image_encoder, decoder_state, device) 827 return DecoderAdapter(unetr)
Get the decoder to predict outputs for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device.
Returns:
The decoder for instance segmentation.
830def get_predictor_and_decoder( 831 model_type: str, 832 checkpoint_path: Union[str, os.PathLike], 833 device: Optional[Union[str, torch.device]] = None, 834 peft_kwargs: Optional[Dict] = None, 835) -> Tuple[SamPredictor, DecoderAdapter]: 836 """Load the SAM model (predictor) and instance segmentation decoder. 837 838 This requires a checkpoint that contains the state for both predictor 839 and decoder. 840 841 Args: 842 model_type: The type of the image encoder used in the SAM model. 843 checkpoint_path: Path to the checkpoint from which to load the data. 844 device: The device. 845 peft_kwargs: Keyword arguments for the PEFT wrapper class. 846 847 Returns: 848 The SAM predictor. 849 The decoder for instance segmentation. 850 """ 851 device = util.get_device(device) 852 predictor, state = util.get_sam_model( 853 model_type=model_type, 854 checkpoint_path=checkpoint_path, 855 device=device, 856 return_state=True, 857 peft_kwargs=peft_kwargs, 858 ) 859 if "decoder_state" not in state: 860 raise ValueError( 861 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 862 ) 863 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 864 return predictor, decoder
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:
The SAM predictor. The decoder for instance segmentation.
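A usage sketch, with the checkpoint path as a placeholder; the checkpoint must contain both the SAM state and a decoder state, otherwise a ValueError is raised as shown above. The returned pair can be passed directly to `InstanceSegmentationWithDecoder`, described next.

```python
from micro_sam.instance_segmentation import get_predictor_and_decoder

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b", checkpoint_path="/path/to/finetuned_checkpoint.pt"
)
```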
916class InstanceSegmentationWithDecoder: 917 """Generates an instance segmentation without prompts, using a decoder. 918 919 Implements the same interface as `AutomaticMaskGenerator`. 920 921 Use this class as follows: 922 ```python 923 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 924 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 925 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 926 ``` 927 928 Args: 929 predictor: The segment anything predictor. 930 decoder: The decoder to predict intermediate representations 931 for instance segmentation. 932 """ 933 def __init__( 934 self, 935 predictor: SamPredictor, 936 decoder: torch.nn.Module, 937 ) -> None: 938 self._predictor = predictor 939 self._decoder = decoder 940 941 # The decoder outputs. 942 self._foreground = None 943 self._center_distances = None 944 self._boundary_distances = None 945 946 self._is_initialized = False 947 948 @property 949 def is_initialized(self): 950 """Whether the mask generator has already been initialized. 951 """ 952 return self._is_initialized 953 954 @torch.no_grad() 955 def initialize( 956 self, 957 image: np.ndarray, 958 image_embeddings: Optional[util.ImageEmbeddings] = None, 959 i: Optional[int] = None, 960 verbose: bool = False, 961 pbar_init: Optional[callable] = None, 962 pbar_update: Optional[callable] = None, 963 ) -> None: 964 """Initialize image embeddings and decoder predictions for an image. 965 966 Args: 967 image: The input image, volume or timeseries. 968 image_embeddings: Optional precomputed image embeddings. 969 See `util.precompute_image_embeddings` for details. 970 i: Index for the image data. Required if `image` has three spatial dimensions 971 or a time dimension and two spatial dimensions. 972 verbose: Whether to be verbose. 973 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 974 Can be used together with pbar_update to handle napari progress bar in other thread. 975 To enables using this function within a threadworker. 976 pbar_update: Callback to update an external progress bar. 977 """ 978 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 979 pbar_init(1, "Initialize instance segmentation with decoder") 980 981 if image_embeddings is None: 982 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 983 984 # Get the image embeddings from the predictor. 985 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 986 embeddings = self._predictor.features 987 input_shape = tuple(self._predictor.input_size) 988 original_shape = tuple(self._predictor.original_size) 989 990 # Run prediction with the UNETR decoder. 991 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 992 assert output.shape[0] == 3, f"{output.shape}" 993 pbar_update(1) 994 pbar_close() 995 996 # Set the state. 
997 self._foreground = output[0] 998 self._center_distances = output[1] 999 self._boundary_distances = output[2] 1000 self._is_initialized = True 1001 1002 def _to_masks(self, segmentation, output_mode): 1003 if output_mode != "binary_mask": 1004 raise NotImplementedError 1005 1006 props = regionprops(segmentation) 1007 ndim = segmentation.ndim 1008 assert ndim in (2, 3) 1009 1010 shape = segmentation.shape 1011 if ndim == 2: 1012 crop_box = [0, shape[1], 0, shape[0]] 1013 else: 1014 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1015 1016 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1017 def to_bbox_2d(bbox): 1018 y0, x0 = bbox[0], bbox[1] 1019 w = bbox[3] - x0 1020 h = bbox[2] - y0 1021 return [x0, w, y0, h] 1022 1023 def to_bbox_3d(bbox): 1024 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1025 w = bbox[5] - x0 1026 h = bbox[4] - y0 1027 d = bbox[3] - y0 1028 return [x0, w, y0, h, z0, d] 1029 1030 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1031 masks = [ 1032 { 1033 "segmentation": segmentation == prop.label, 1034 "area": prop.area, 1035 "bbox": to_bbox(prop.bbox), 1036 "crop_box": crop_box, 1037 "seg_id": prop.label, 1038 } for prop in props 1039 ] 1040 return masks 1041 1042 def generate( 1043 self, 1044 center_distance_threshold: float = 0.5, 1045 boundary_distance_threshold: float = 0.5, 1046 foreground_threshold: float = 0.5, 1047 foreground_smoothing: float = 1.0, 1048 distance_smoothing: float = 1.6, 1049 min_size: int = 0, 1050 output_mode: Optional[str] = "binary_mask", 1051 tile_shape: Optional[Tuple[int, int]] = None, 1052 halo: Optional[Tuple[int, int]] = None, 1053 n_threads: Optional[int] = None, 1054 ) -> List[Dict[str, Any]]: 1055 """Generate instance segmentation for the currently initialized image. 1056 1057 Args: 1058 center_distance_threshold: Center distance predictions below this value will be 1059 used to find seeds (intersected with thresholded boundary distance predictions). 1060 boundary_distance_threshold: Boundary distance predictions below this value will be 1061 used to find seeds (intersected with thresholded center distance predictions). 1062 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1063 checkerboard artifacts in the prediction. 1064 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1065 distance_smoothing: Sigma value for smoothing the distance predictions. 1066 min_size: Minimal object size in the segmentation result. 1067 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1068 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1069 This parameter is independent from the tile shape for computing the embeddings. 1070 If not given then post-processing will not be parallelized. 1071 halo: Halo for parallel post-processing. See also `tile_shape`. 1072 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1073 1074 Returns: 1075 The instance segmentation masks. 1076 """ 1077 if not self.is_initialized: 1078 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1079 1080 if foreground_smoothing > 0: 1081 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1082 else: 1083 foreground = self._foreground 1084 1085 if tile_shape is None: 1086 segmentation = watershed_from_center_and_boundary_distances( 1087 center_distances=self._center_distances, 1088 boundary_distances=self._boundary_distances, 1089 foreground_map=foreground, 1090 center_distance_threshold=center_distance_threshold, 1091 boundary_distance_threshold=boundary_distance_threshold, 1092 foreground_threshold=foreground_threshold, 1093 distance_smoothing=distance_smoothing, 1094 min_size=min_size, 1095 ) 1096 else: 1097 if halo is None: 1098 raise ValueError("You must pass a value for halo if tile_shape is given.") 1099 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1100 center_distances=self._center_distances, 1101 boundary_distances=self._boundary_distances, 1102 foreground_map=foreground, 1103 center_distance_threshold=center_distance_threshold, 1104 boundary_distance_threshold=boundary_distance_threshold, 1105 foreground_threshold=foreground_threshold, 1106 distance_smoothing=distance_smoothing, 1107 min_size=min_size, 1108 tile_shape=tile_shape, 1109 halo=halo, 1110 n_threads=n_threads, 1111 verbose=False, 1112 ) 1113 1114 if output_mode is not None: 1115 segmentation = self._to_masks(segmentation, output_mode) 1116 return segmentation 1117 1118 def get_state(self) -> Dict[str, Any]: 1119 """Get the initialized state of the instance segmenter. 1120 1121 Returns: 1122 Instance segmentation state. 1123 """ 1124 if not self.is_initialized: 1125 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1126 1127 return { 1128 "foreground": self._foreground, 1129 "center_distances": self._center_distances, 1130 "boundary_distances": self._boundary_distances, 1131 } 1132 1133 def set_state(self, state: Dict[str, Any]) -> None: 1134 """Set the state of the instance segmenter. 1135 1136 Args: 1137 state: The instance segmentation state 1138 """ 1139 self._foreground = state["foreground"] 1140 self._center_distances = state["center_distances"] 1141 self._boundary_distances = state["boundary_distances"] 1142 self._is_initialized = True 1143 1144 def clear_state(self): 1145 """Clear the state of the instance segmenter. 1146 """ 1147 self._foreground = None 1148 self._center_distances = None 1149 self._boundary_distances = None 1150 self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as `AutomaticMaskGenerator`.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
933 def __init__( 934 self, 935 predictor: SamPredictor, 936 decoder: torch.nn.Module, 937 ) -> None: 938 self._predictor = predictor 939 self._decoder = decoder 940 941 # The decoder outputs. 942 self._foreground = None 943 self._center_distances = None 944 self._boundary_distances = None 945 946 self._is_initialized = False
948 @property 949 def is_initialized(self): 950 """Whether the mask generator has already been initialized. 951 """ 952 return self._is_initialized
Whether the mask generator has already been initialized.
954 @torch.no_grad() 955 def initialize( 956 self, 957 image: np.ndarray, 958 image_embeddings: Optional[util.ImageEmbeddings] = None, 959 i: Optional[int] = None, 960 verbose: bool = False, 961 pbar_init: Optional[callable] = None, 962 pbar_update: Optional[callable] = None, 963 ) -> None: 964 """Initialize image embeddings and decoder predictions for an image. 965 966 Args: 967 image: The input image, volume or timeseries. 968 image_embeddings: Optional precomputed image embeddings. 969 See `util.precompute_image_embeddings` for details. 970 i: Index for the image data. Required if `image` has three spatial dimensions 971 or a time dimension and two spatial dimensions. 972 verbose: Whether to be verbose. 973 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 974 Can be used together with pbar_update to handle napari progress bar in other thread. 975 To enables using this function within a threadworker. 976 pbar_update: Callback to update an external progress bar. 977 """ 978 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 979 pbar_init(1, "Initialize instance segmentation with decoder") 980 981 if image_embeddings is None: 982 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 983 984 # Get the image embeddings from the predictor. 985 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 986 embeddings = self._predictor.features 987 input_shape = tuple(self._predictor.input_size) 988 original_shape = tuple(self._predictor.original_size) 989 990 # Run prediction with the UNETR decoder. 991 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 992 assert output.shape[0] == 3, f"{output.shape}" 993 pbar_update(1) 994 pbar_close() 995 996 # Set the state. 997 self._foreground = output[0] 998 self._center_distances = output[1] 999 self._boundary_distances = output[2] 1000 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
1042 def generate( 1043 self, 1044 center_distance_threshold: float = 0.5, 1045 boundary_distance_threshold: float = 0.5, 1046 foreground_threshold: float = 0.5, 1047 foreground_smoothing: float = 1.0, 1048 distance_smoothing: float = 1.6, 1049 min_size: int = 0, 1050 output_mode: Optional[str] = "binary_mask", 1051 tile_shape: Optional[Tuple[int, int]] = None, 1052 halo: Optional[Tuple[int, int]] = None, 1053 n_threads: Optional[int] = None, 1054 ) -> List[Dict[str, Any]]: 1055 """Generate instance segmentation for the currently initialized image. 1056 1057 Args: 1058 center_distance_threshold: Center distance predictions below this value will be 1059 used to find seeds (intersected with thresholded boundary distance predictions). 1060 boundary_distance_threshold: Boundary distance predictions below this value will be 1061 used to find seeds (intersected with thresholded center distance predictions). 1062 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1063 checkerboard artifacts in the prediction. 1064 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1065 distance_smoothing: Sigma value for smoothing the distance predictions. 1066 min_size: Minimal object size in the segmentation result. 1067 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1068 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1069 This parameter is independent from the tile shape for computing the embeddings. 1070 If not given then post-processing will not be parallelized. 1071 halo: Halo for parallel post-processing. See also `tile_shape`. 1072 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1073 1074 Returns: 1075 The instance segmentation masks. 1076 """ 1077 if not self.is_initialized: 1078 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 1079 1080 if foreground_smoothing > 0: 1081 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1082 else: 1083 foreground = self._foreground 1084 1085 if tile_shape is None: 1086 segmentation = watershed_from_center_and_boundary_distances( 1087 center_distances=self._center_distances, 1088 boundary_distances=self._boundary_distances, 1089 foreground_map=foreground, 1090 center_distance_threshold=center_distance_threshold, 1091 boundary_distance_threshold=boundary_distance_threshold, 1092 foreground_threshold=foreground_threshold, 1093 distance_smoothing=distance_smoothing, 1094 min_size=min_size, 1095 ) 1096 else: 1097 if halo is None: 1098 raise ValueError("You must pass a value for halo if tile_shape is given.") 1099 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1100 center_distances=self._center_distances, 1101 boundary_distances=self._boundary_distances, 1102 foreground_map=foreground, 1103 center_distance_threshold=center_distance_threshold, 1104 boundary_distance_threshold=boundary_distance_threshold, 1105 foreground_threshold=foreground_threshold, 1106 distance_smoothing=distance_smoothing, 1107 min_size=min_size, 1108 tile_shape=tile_shape, 1109 halo=halo, 1110 n_threads=n_threads, 1111 verbose=False, 1112 ) 1113 1114 if output_mode is not None: 1115 segmentation = self._to_masks(segmentation, output_mode) 1116 return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask.
- distance_smoothing: Sigma value for smoothing the distance predictions.
- min_size: Minimal object size in the segmentation result.
- output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
- tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
- halo: Halo for parallel post-processing. See also `tile_shape`.
- n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
Returns:
The instance segmentation masks.
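A hedged sketch of parallelizing the watershed post-processing over tiles, assuming `segmenter` has been initialized as in the usage snippet above; the tile shape, halo and min_size values are illustrative.

```python
masks = segmenter.generate(
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.5,
    min_size=100,
    tile_shape=(512, 512),  # post-processing tiles, independent of the embedding tiling
    halo=(64, 64),          # must be given whenever tile_shape is passed
)
```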
1118 def get_state(self) -> Dict[str, Any]: 1119 """Get the initialized state of the instance segmenter. 1120 1121 Returns: 1122 Instance segmentation state. 1123 """ 1124 if not self.is_initialized: 1125 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1126 1127 return { 1128 "foreground": self._foreground, 1129 "center_distances": self._center_distances, 1130 "boundary_distances": self._boundary_distances, 1131 }
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
1133 def set_state(self, state: Dict[str, Any]) -> None: 1134 """Set the state of the instance segmenter. 1135 1136 Args: 1137 state: The instance segmentation state 1138 """ 1139 self._foreground = state["foreground"] 1140 self._center_distances = state["center_distances"] 1141 self._boundary_distances = state["boundary_distances"] 1142 self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state.
1153class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1154 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1155 """ 1156 1157 # Apply the decoder in a batched fashion, and then perform the resizing independently per output. 1158 # This is necessary, because the individual tiles may have different tile shapes due to border tiles. 1159 def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes): 1160 batched_embeddings = torch.cat(batched_embeddings) 1161 output = self._decoder._forward_impl(batched_embeddings) 1162 1163 batched_output = [] 1164 for x, input_shape, original_shape in zip(output, input_shapes, original_shapes): 1165 x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0) 1166 batched_output.append(x.cpu().numpy()) 1167 return batched_output 1168 1169 @torch.no_grad() 1170 def initialize( 1171 self, 1172 image: np.ndarray, 1173 image_embeddings: Optional[util.ImageEmbeddings] = None, 1174 i: Optional[int] = None, 1175 tile_shape: Optional[Tuple[int, int]] = None, 1176 halo: Optional[Tuple[int, int]] = None, 1177 verbose: bool = False, 1178 pbar_init: Optional[callable] = None, 1179 pbar_update: Optional[callable] = None, 1180 batch_size: int = 1, 1181 ) -> None: 1182 """Initialize image embeddings and decoder predictions for an image. 1183 1184 Args: 1185 image: The input image, volume or timeseries. 1186 image_embeddings: Optional precomputed image embeddings. 1187 See `util.precompute_image_embeddings` for details. 1188 i: Index for the image data. Required if `image` has three spatial dimensions 1189 or a time dimension and two spatial dimensions. 1190 tile_shape: Shape of the tiles for precomputing image embeddings. 1191 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1192 verbose: Dummy input to be compatible with other function signatures. 1193 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1194 Can be used together with pbar_update to handle napari progress bar in other thread. 1195 To enables using this function within a threadworker. 1196 pbar_update: Callback to update an external progress bar. 1197 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 1198 """ 1199 original_size = image.shape[:2] 1200 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1201 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 1202 ) 1203 tiling = blocking([0, 0], original_size, tile_shape) 1204 1205 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1206 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1207 1208 foreground = np.zeros(original_size, dtype="float32") 1209 center_distances = np.zeros(original_size, dtype="float32") 1210 boundary_distances = np.zeros(original_size, dtype="float32") 1211 1212 n_tiles = tiling.numberOfBlocks 1213 n_batches = int(np.ceil(n_tiles / batch_size)) 1214 1215 for batch_id in range(n_batches): 1216 tile_start = batch_id * batch_size 1217 tile_stop = min(tile_start + batch_size, n_tiles) 1218 1219 batched_embeddings, input_shapes, original_shapes = [], [], [] 1220 for tile_id in range(tile_start, tile_stop): 1221 # Get the image embeddings from the predictor for this tile. 
1222 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1223 1224 batched_embeddings.append(self._predictor.features) 1225 input_shapes.append(tuple(self._predictor.input_size)) 1226 original_shapes.append(tuple(self._predictor.original_size)) 1227 1228 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1229 1230 for output_id, tile_id in enumerate(range(tile_start, tile_stop)): 1231 output = batched_output[output_id] 1232 assert output.shape[0] == 3 1233 1234 # Set the predictions in the output for this tile. 1235 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1236 local_bb = tuple( 1237 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1238 ) 1239 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1240 1241 foreground[inner_bb] = output[0][local_bb] 1242 center_distances[inner_bb] = output[1][local_bb] 1243 boundary_distances[inner_bb] = output[2][local_bb] 1244 pbar_update(1) 1245 1246 pbar_close() 1247 1248 # Set the state. 1249 self._foreground = foreground 1250 self._center_distances = center_distances 1251 self._boundary_distances = boundary_distances 1252 self._is_initialized = True
Same as InstanceSegmentationWithDecoder but for tiled image embeddings.
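A hedged usage sketch for the tiled segmenter. The predictor and decoder are assumed to be loaded elsewhere (e.g. via micro_sam's model-loading utilities), the file name and tile parameters are illustrative only, and the shape of the generate output depends on the chosen output mode.

import imageio.v3 as imageio

# `predictor` (a SamPredictor) and `decoder` (the segmentation decoder) are
# assumed to be available already; loading them is not shown here.
image = imageio.imread("large_image.tif")  # hypothetical large 2d image

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(128, 128), batch_size=4)

# `generate` is inherited from InstanceSegmentationWithDecoder; with a binary
# mask output mode the result can be converted to an instance segmentation
# with `mask_data_to_segmentation`.
masks = segmenter.generate()
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)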
1169 @torch.no_grad() 1170 def initialize( 1171 self, 1172 image: np.ndarray, 1173 image_embeddings: Optional[util.ImageEmbeddings] = None, 1174 i: Optional[int] = None, 1175 tile_shape: Optional[Tuple[int, int]] = None, 1176 halo: Optional[Tuple[int, int]] = None, 1177 verbose: bool = False, 1178 pbar_init: Optional[callable] = None, 1179 pbar_update: Optional[callable] = None, 1180 batch_size: int = 1, 1181 ) -> None: 1182 """Initialize image embeddings and decoder predictions for an image. 1183 1184 Args: 1185 image: The input image, volume or timeseries. 1186 image_embeddings: Optional precomputed image embeddings. 1187 See `util.precompute_image_embeddings` for details. 1188 i: Index for the image data. Required if `image` has three spatial dimensions 1189 or a time dimension and two spatial dimensions. 1190 tile_shape: Shape of the tiles for precomputing image embeddings. 1191 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1192 verbose: Dummy input to be compatible with other function signatures. 1193 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1194 Can be used together with pbar_update to handle napari progress bar in other thread. 1195 To enables using this function within a threadworker. 1196 pbar_update: Callback to update an external progress bar. 1197 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 1198 """ 1199 original_size = image.shape[:2] 1200 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1201 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 1202 ) 1203 tiling = blocking([0, 0], original_size, tile_shape) 1204 1205 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1206 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1207 1208 foreground = np.zeros(original_size, dtype="float32") 1209 center_distances = np.zeros(original_size, dtype="float32") 1210 boundary_distances = np.zeros(original_size, dtype="float32") 1211 1212 n_tiles = tiling.numberOfBlocks 1213 n_batches = int(np.ceil(n_tiles / batch_size)) 1214 1215 for batch_id in range(n_batches): 1216 tile_start = batch_id * batch_size 1217 tile_stop = min(tile_start + batch_size, n_tiles) 1218 1219 batched_embeddings, input_shapes, original_shapes = [], [], [] 1220 for tile_id in range(tile_start, tile_stop): 1221 # Get the image embeddings from the predictor for this tile. 1222 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1223 1224 batched_embeddings.append(self._predictor.features) 1225 input_shapes.append(tuple(self._predictor.input_size)) 1226 original_shapes.append(tuple(self._predictor.original_size)) 1227 1228 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1229 1230 for output_id, tile_id in enumerate(range(tile_start, tile_stop)): 1231 output = batched_output[output_id] 1232 assert output.shape[0] == 3 1233 1234 # Set the predictions in the output for this tile. 
1235 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1236 local_bb = tuple( 1237 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1238 ) 1239 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1240 1241 foreground[inner_bb] = output[0][local_bb] 1242 center_distances[inner_bb] = output[1][local_bb] 1243 boundary_distances[inner_bb] = output[2][local_bb] 1244 pbar_update(1) 1245 1246 pbar_close() 1247 1248 # Set the state. 1249 self._foreground = foreground 1250 self._center_distances = center_distances 1251 self._boundary_distances = boundary_distances 1252 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See util.precompute_image_embeddings for details.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: Shape of the tiles for precomputing image embeddings.
- halo: Overlap of the tiles for tiled precomputation of image embeddings.
- verbose: Dummy input to be compatible with other function signatures.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using the function within a thread worker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding computation and segmentation decoder prediction.
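Because pbar_init and pbar_update are plain callbacks, progress reporting can be wired to any progress bar implementation. The sketch below uses tqdm; the callback signatures follow the description above, and the segmenter, image and tile parameters are carried over from the previous example as assumptions.

from tqdm import tqdm

pbar = None

def pbar_init(n_steps, description):
    # Called once with the number of tiles and a description for the progress bar.
    global pbar
    pbar = tqdm(total=n_steps, desc=description)

def pbar_update(n):
    # Called after each processed tile.
    pbar.update(n)

segmenter.initialize(
    image, tile_shape=(1024, 1024), halo=(128, 128),
    pbar_init=pbar_init, pbar_update=pbar_update, batch_size=4,
)
if pbar is not None:
    pbar.close()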
1255def get_amg( 1256 predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs, 1257) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1258 """Get the automatic mask generator class. 1259 1260 Args: 1261 predictor: The segment anything predictor. 1262 is_tiled: Whether tiled embeddings are used. 1263 decoder: Decoder to predict instance segmentation. 1264 kwargs: The keyword arguments for the amg class. 1265 1266 Returns: 1267 The automatic mask generator. 1268 """ 1269 if decoder is None: 1270 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1271 segmenter = segmenter_class(predictor, **kwargs) 1272 else: 1273 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1274 segmenter = segmenter_class(predictor, decoder, **kwargs) 1275 1276 return segmenter
Get the automatic mask generator class.
Arguments:
- predictor: The segment anything predictor.
- is_tiled: Whether tiled embeddings are used.
- decoder: Decoder to predict instance segmentation.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator.
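A short example of selecting a segmenter via get_amg. Loading the predictor through micro_sam.util.get_sam_model and the choice of model type are assumptions about the surrounding API; how to obtain a matching decoder is version-dependent and therefore only indicated in a comment.

from micro_sam.util import get_sam_model

# Load a Segment Anything predictor; the model type is only an example.
predictor = get_sam_model(model_type="vit_b_lm")

# Without a decoder this returns an (optionally tiled) automatic mask generator.
amg = get_amg(predictor, is_tiled=False)

# With a decoder and tiled embeddings it returns a TiledInstanceSegmentationWithDecoder.
# `decoder` is assumed to be loaded together with the predictor (not shown here).
# segmenter = get_amg(predictor, is_tiled=True, decoder=decoder)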