micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
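A minimal sketch of the typical workflow with this module, assuming a SAM predictor loaded via `micro_sam.util.get_sam_model` and an example image file; the model type, file name, and thresholds are illustrative placeholders, not fixed recommendations:

```python
import imageio.v3 as imageio

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator

# Load a (possibly fine-tuned) SAM model and an example 2D image.
predictor = get_sam_model(model_type="vit_b")
image = imageio.imread("example_image.tif")

# Expensive step: compute embeddings and mask proposals for a point grid.
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)

# Cheap step: filter the proposals; this can be re-run with different thresholds.
masks = amg.generate(pred_iou_thresh=0.88, stability_score_thresh=0.95)
```

The mask records returned by `generate` can be converted into a label image with `mask_data_to_segmentation`, which is documented below.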
1"""Automated instance segmentation functionality. 2The classes implemented here extend the automatic instance segmentation from Segment Anything: 3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 4""" 5 6import os 7import warnings 8from abc import ABC 9from copy import deepcopy 10from collections import OrderedDict 11from typing import Any, Dict, List, Optional, Tuple, Union 12 13import vigra 14import numpy as np 15from skimage.measure import label, regionprops 16from skimage.segmentation import relabel_sequential 17 18import torch 19from torchvision.ops.boxes import batched_nms, box_area 20 21from torch_em.model import UNETR 22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 23 24import elf.parallel as parallel 25from elf.parallel.filters import apply_filter 26 27from nifty.tools import blocking 28 29import segment_anything.utils.amg as amg_utils 30from segment_anything.predictor import SamPredictor 31 32from . import util 33from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 34 35# 36# Utility Functionality 37# 38 39 40class _FakeInput: 41 def __init__(self, shape): 42 self.shape = shape 43 44 def __getitem__(self, index): 45 block_shape = tuple(ind.stop - ind.start for ind in index) 46 return np.zeros(block_shape, dtype="float32") 47 48 49def mask_data_to_segmentation( 50 masks: List[Dict[str, Any]], 51 with_background: bool, 52 min_object_size: int = 0, 53 max_object_size: Optional[int] = None, 54 label_masks: bool = True, 55) -> np.ndarray: 56 """Convert the output of the automatic mask generation to an instance segmentation. 57 58 Args: 59 masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. 60 Only supports output_mode=binary_mask. 61 with_background: Whether the segmentation has background. If yes this function assures that the largest 62 object in the output will be mapped to zero (the background value). 63 min_object_size: The minimal size of an object in pixels. 64 max_object_size: The maximal size of an object in pixels. 65 label_masks: Whether to apply connected components to the result before removing small objects. 66 67 Returns: 68 The instance segmentation. 69 """ 70 71 masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 72 # we could also get the shape from the crop box 73 shape = next(iter(masks))["segmentation"].shape 74 segmentation = np.zeros(shape, dtype="uint32") 75 76 def require_numpy(mask): 77 return mask.cpu().numpy() if torch.is_tensor(mask) else mask 78 79 seg_id = 1 80 for mask in masks: 81 if mask["area"] < min_object_size: 82 continue 83 if max_object_size is not None and mask["area"] > max_object_size: 84 continue 85 86 this_seg_id = mask.get("seg_id", seg_id) 87 segmentation[require_numpy(mask["segmentation"])] = this_seg_id 88 seg_id = this_seg_id + 1 89 90 if label_masks: 91 segmentation = label(segmentation).astype(segmentation.dtype) 92 93 seg_ids, sizes = np.unique(segmentation, return_counts=True) 94 95 # In some cases objects may be smaller than peviously calculated, 96 # since they are covered by other objects. We ensure these also get 97 # filtered out here. 98 filter_ids = seg_ids[sizes < min_object_size] 99 100 # If we run segmentation with background we also map the largest segment 101 # (the most likely background object) to zero. This is often zero already, 102 # but it does not hurt to reset that to zero either. 
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    with_background: bool,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
            Only supports output_mode=binary_mask.
        with_background: Whether the segmentation has background. If yes this function assures that the largest
            object in the output will be mapped to zero (the background value).
        min_object_size: The minimal size of an object in pixels.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.

    Returns:
        The instance segmentation.
    """

    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    # we could also get the shape from the crop box
    shape = next(iter(masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask in masks:
        if mask["area"] < min_object_size:
            continue
        if max_object_size is not None and mask["area"] > max_object_size:
            continue

        this_seg_id = mask.get("seg_id", seg_id)
        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
        seg_id = this_seg_id + 1

    if label_masks:
        segmentation = label(segmentation).astype(segmentation.dtype)

    seg_ids, sizes = np.unique(segmentation, return_counts=True)

    # In some cases objects may be smaller than previously calculated,
    # since they are covered by other objects. We ensure these also get
    # filtered out here.
    filter_ids = seg_ids[sizes < min_object_size]

    # If we run segmentation with background we also map the largest segment
    # (the most likely background object) to zero. This is often zero already,
    # but it does not hurt to reset that to zero either.
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    segmentation[np.isin(segmentation, filter_ids)] = 0
    segmentation = relabel_sequential(segmentation)[0]

    return segmentation
Convert the output of the automatic mask generation to an instance segmentation.
Arguments:
- masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
- with_background: Whether the segmentation has background. If yes this function assures that the largest object in the output will be mapped to zero (the background value).
- min_object_size: The minimal size of an object in pixels.
- max_object_size: The maximal size of an object in pixels.
- label_masks: Whether to apply connected components to the result before removing small objects.
Returns:
The instance segmentation.
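As an illustration, a short sketch of converting generator output into a label image; it assumes `masks` is the list returned by `AutomaticMaskGenerator.generate` with the default `output_mode="binary_mask"`, and the size thresholds are arbitrary example values:

```python
from micro_sam.instance_segmentation import mask_data_to_segmentation

# `masks` is assumed to be the mask record list from a mask generator,
# each record containing "segmentation" and "area" entries.
segmentation = mask_data_to_segmentation(
    masks,
    with_background=True,    # map the largest segment (background) to label 0
    min_object_size=100,     # drop objects smaller than 100 pixels
    max_object_size=50_000,  # drop implausibly large objects
)
print("Number of segmented objects:", segmentation.max())
```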
class AMGBase(ABC):
    """Base class for the automatic mask generators.
    """
    def __init__(self):
        # the state that has to be computed by the 'initialize' method of the child classes
        self._is_initialized = False
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None

    @property
    def is_initialized(self):
        """Whether the mask generator has already been initialized.
        """
        return self._is_initialized

    @property
    def crop_list(self):
        """The list of mask data after initialization.
        """
        return self._crop_list

    @property
    def crop_boxes(self):
        """The list of crop boxes.
        """
        return self._crop_boxes

    @property
    def original_size(self):
        """The original image size.
        """
        return self._original_size

    def _postprocess_batch(
        self,
        data,
        crop_box,
        original_size,
        pred_iou_thresh,
        stability_score_thresh,
        box_nms_thresh,
    ):
        orig_h, orig_w = original_size

        # filter by predicted IoU
        if pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > pred_iou_thresh
            data.filter(keep_mask)

        # filter by stability score
        if stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= stability_score_thresh
            data.filter(keep_mask)

        # filter boxes that touch crop boundaries
        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # return to the original image frame
        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        # the data from embedding based segmentation doesn't have the points
        # so we skip if the corresponding key can't be found
        try:
            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
        except KeyError:
            pass

        return data

    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):

        if len(mask_data["rles"]) == 0:
            return mask_data

        # filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = amg_utils.rle_to_mask(rle)

            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
            # give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores, dtype=torch.float),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data

    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
        # filter small disconnected regions and holes in masks
        if min_mask_region_area > 0:
            mask_data = self._postprocess_small_regions(
                mask_data,
                min_mask_region_area,
                max(box_nms_thresh, crop_nms_thresh),
            )

        # encode masks
        if output_mode == "coco_rle":
            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif output_mode == "binary_mask":
            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            # the data from embedding based segmentation doesn't have the points
            # so we skip if the corresponding key can't be found
            try:
                ann["point_coords"] = [mask_data["points"][idx].tolist()]
            except KeyError:
                pass
            curr_anns.append(ann)

        return curr_anns

    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
        orig_h, orig_w = original_size

        # serialize predictions and store in MaskData
        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
        if points is not None:
            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)

        del masks

        # calculate the stability scores
        data["stability_score"] = amg_utils.calculate_stability_score(
            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
        )

        # threshold masks and calculate boxes
        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
        data["masks"] = data["masks"].type(torch.bool)
        data["boxes"] = batched_mask_to_box(data["masks"])

        # compress to RLE
        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    def get_state(self) -> Dict[str, Any]:
        """Get the initialized state of the mask generator.

        Returns:
            State of the mask generator.
        """
        if not self.is_initialized:
            raise RuntimeError("The state has not been computed yet. Call initialize first.")

        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

    def set_state(self, state: Dict[str, Any]) -> None:
        """Set the state of the mask generator.

        Args:
            state: The state of the mask generator, e.g. from serialized state.
        """
        self._crop_list = state["crop_list"]
        self._crop_boxes = state["crop_boxes"]
        self._original_size = state["original_size"]
        self._is_initialized = True

    def clear_state(self):
        """Clear the state of the mask generator.
        """
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None
        self._is_initialized = False
Base class for the automatic mask generators.
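The mask generators derived from this base class split the work into an expensive `initialize` and a cheap `generate` call. A sketch of a small post-processing grid search that exploits this split; it assumes `amg` is an already initialized generator (e.g. an `AutomaticMaskGenerator` after `amg.initialize(image)`), and the threshold grid is an arbitrary example:

```python
# Only the cheap filtering step is repeated; the embeddings and mask proposals
# computed by initialize() are reused for every parameter combination.
results = {}
for iou_thresh in (0.7, 0.8, 0.88):
    for stability_thresh in (0.9, 0.95):
        masks = amg.generate(pred_iou_thresh=iou_thresh, stability_score_thresh=stability_thresh)
        results[(iou_thresh, stability_thresh)] = len(masks)
```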
is_initialized

Whether the mask generator has already been initialized.
135 @property 136 def crop_list(self): 137 """The list of mask data after initialization. 138 """ 139 return self._crop_list
The list of mask data after initialization.
141 @property 142 def crop_boxes(self): 143 """The list of crop boxes. 144 """ 145 return self._crop_boxes
The list of crop boxes.
147 @property 148 def original_size(self): 149 """The original image size. 150 """ 151 return self._original_size
The original image size.
308 def get_state(self) -> Dict[str, Any]: 309 """Get the initialized state of the mask generator. 310 311 Returns: 312 State of the mask generator. 313 """ 314 if not self.is_initialized: 315 raise RuntimeError("The state has not been computed yet. Call initialize first.") 316 317 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
319 def set_state(self, state: Dict[str, Any]) -> None: 320 """Set the state of the mask generator. 321 322 Args: 323 state: The state of the mask generator, e.g. from serialized state. 324 """ 325 self._crop_list = state["crop_list"] 326 self._crop_boxes = state["crop_boxes"] 327 self._original_size = state["original_size"] 328 self._is_initialized = True
Set the state of the mask generator.
Arguments:
- state: The state of the mask generator, e.g. from serialized state.
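The state accessors above make it possible to cache the result of the expensive `initialize` step and restore it later. A minimal sketch, assuming an already initialized mask generator named `amg`; whether the state pickles cleanly depends on its contents (MaskData objects and crop boxes), so treat the serialization step as an assumption:

```python
import pickle

# `amg` is assumed to be a mask generator that was already initialized on an image;
# otherwise `get_state` raises a RuntimeError.
state = amg.get_state()

# Hypothetical caching step: pickling is a plausible, but not guaranteed,
# way to persist the MaskData objects and crop boxes in the state.
with open("amg_state.pkl", "wb") as f:
    pickle.dump(state, f)

# Later: restore the state and generate masks without re-initializing.
with open("amg_state.pkl", "rb") as f:
    amg.set_state(pickle.load(f))
masks = amg.generate()
```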
339class AutomaticMaskGenerator(AMGBase): 340 """Generates an instance segmentation without prompts, using a point grid. 341 342 This class implements the same logic as 343 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 344 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 345 to filter these masks to enable grid search and interactively changing the post-processing. 346 347 Use this class as follows: 348 ```python 349 amg = AutomaticMaskGenerator(predictor) 350 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 351 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 352 ``` 353 354 Args: 355 predictor: The segment anything predictor. 356 points_per_side: The number of points to be sampled along one side of the image. 357 If None, `point_grids` must provide explicit point sampling. 358 points_per_batch: The number of points run simultaneously by the model. 359 Higher numbers may be faster but use more GPU memory. 360 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 361 crop_overlap_ratio: Sets the degree to which crops overlap. 362 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 363 point_grids: A lisst over explicit grids of points used for sampling masks. 364 Normalized to [0, 1] with respect to the image coordinate system. 365 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 366 """ 367 def __init__( 368 self, 369 predictor: SamPredictor, 370 points_per_side: Optional[int] = 32, 371 points_per_batch: Optional[int] = None, 372 crop_n_layers: int = 0, 373 crop_overlap_ratio: float = 512 / 1500, 374 crop_n_points_downscale_factor: int = 1, 375 point_grids: Optional[List[np.ndarray]] = None, 376 stability_score_offset: float = 1.0, 377 ): 378 super().__init__() 379 380 if points_per_side is not None: 381 self.point_grids = amg_utils.build_all_layer_point_grids( 382 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 383 ) 384 elif point_grids is not None: 385 self.point_grids = point_grids 386 else: 387 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 388 389 self._predictor = predictor 390 self._points_per_side = points_per_side 391 392 # we set the points per batch to 16 for mps for performance reasons 393 # and otherwise keep them at the default of 64 394 if points_per_batch is None: 395 points_per_batch = 16 if str(predictor.device) == "mps" else 64 396 self._points_per_batch = points_per_batch 397 398 self._crop_n_layers = crop_n_layers 399 self._crop_overlap_ratio = crop_overlap_ratio 400 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 401 self._stability_score_offset = stability_score_offset 402 403 def _process_batch(self, points, im_size, crop_box, original_size): 404 # run model on this batch 405 transformed_points = self._predictor.transform.apply_coords(points, im_size) 406 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 407 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 408 masks, iou_preds, _ = self._predictor.predict_torch( 409 point_coords=in_points[:, None, :], 410 point_labels=in_labels[:, None], 411 multimask_output=True, 412 return_logits=True, 413 ) 414 data = 
self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 415 del masks 416 return data 417 418 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 419 # Crop the image and calculate embeddings. 420 x0, y0, x1, y1 = crop_box 421 cropped_im = image[y0:y1, x0:x1, :] 422 cropped_im_size = cropped_im.shape[:2] 423 424 if not precomputed_embeddings: 425 self._predictor.set_image(cropped_im) 426 427 # Get the points for this crop. 428 points_scale = np.array(cropped_im_size)[None, ::-1] 429 points_for_image = self.point_grids[crop_layer_idx] * points_scale 430 431 # Generate masks for this crop in batches. 432 data = amg_utils.MaskData() 433 n_batches = len(points_for_image) // self._points_per_batch +\ 434 int(len(points_for_image) % self._points_per_batch != 0) 435 if pbar_init is not None: 436 pbar_init(n_batches, "Predict masks for point grid prompts") 437 438 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 439 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 440 data.cat(batch_data) 441 del batch_data 442 if pbar_update is not None: 443 pbar_update(1) 444 445 if not precomputed_embeddings: 446 self._predictor.reset_image() 447 448 return data 449 450 @torch.no_grad() 451 def initialize( 452 self, 453 image: np.ndarray, 454 image_embeddings: Optional[util.ImageEmbeddings] = None, 455 i: Optional[int] = None, 456 verbose: bool = False, 457 pbar_init: Optional[callable] = None, 458 pbar_update: Optional[callable] = None, 459 ) -> None: 460 """Initialize image embeddings and masks for an image. 461 462 Args: 463 image: The input image, volume or timeseries. 464 image_embeddings: Optional precomputed image embeddings. 465 See `util.precompute_image_embeddings` for details. 466 i: Index for the image data. Required if `image` has three spatial dimensions 467 or a time dimension and two spatial dimensions. 468 verbose: Whether to print computation progress. 469 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 470 Can be used together with pbar_update to handle napari progress bar in other thread. 471 To enables using this function within a threadworker. 472 pbar_update: Callback to update an external progress bar. 473 """ 474 original_size = image.shape[:2] 475 self._original_size = original_size 476 477 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 478 original_size, self._crop_n_layers, self._crop_overlap_ratio 479 ) 480 481 # We can set fixed image embeddings if we only have a single crop box (the default setting). 482 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 
483 if len(crop_boxes) == 1: 484 if image_embeddings is None: 485 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 486 util.set_precomputed(self._predictor, image_embeddings, i=i) 487 precomputed_embeddings = True 488 else: 489 precomputed_embeddings = False 490 491 # we need to cast to the image representation that is compatible with SAM 492 image = util._to_image(image) 493 494 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 495 496 crop_list = [] 497 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 498 crop_data = self._process_crop( 499 image, crop_box, layer_idx, 500 precomputed_embeddings=precomputed_embeddings, 501 pbar_init=pbar_init, pbar_update=pbar_update, 502 ) 503 crop_list.append(crop_data) 504 pbar_close() 505 506 self._is_initialized = True 507 self._crop_list = crop_list 508 self._crop_boxes = crop_boxes 509 510 @torch.no_grad() 511 def generate( 512 self, 513 pred_iou_thresh: float = 0.88, 514 stability_score_thresh: float = 0.95, 515 box_nms_thresh: float = 0.7, 516 crop_nms_thresh: float = 0.7, 517 min_mask_region_area: int = 0, 518 output_mode: str = "binary_mask", 519 ) -> List[Dict[str, Any]]: 520 """Generate instance segmentation for the currently initialized image. 521 522 Args: 523 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 524 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 525 under changes to the cutoff used to binarize the model prediction. 526 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 527 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 528 min_mask_region_area: Minimal size for the predicted masks. 529 output_mode: The form masks are returned in. 530 531 Returns: 532 The instance segmentation masks. 533 """ 534 if not self.is_initialized: 535 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 536 537 data = amg_utils.MaskData() 538 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 539 crop_data = self._postprocess_batch( 540 data=deepcopy(data_), 541 crop_box=crop_box, original_size=self.original_size, 542 pred_iou_thresh=pred_iou_thresh, 543 stability_score_thresh=stability_score_thresh, 544 box_nms_thresh=box_nms_thresh 545 ) 546 data.cat(crop_data) 547 548 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 549 # Prefer masks from smaller crops 550 scores = 1 / box_area(data["crop_boxes"]) 551 scores = scores.to(data["boxes"].device) 552 keep_by_nms = batched_nms( 553 data["boxes"].float(), 554 scores, 555 torch.zeros_like(data["boxes"][:, 0]), # categories 556 iou_threshold=crop_nms_thresh, 557 ) 558 data.filter(keep_by_nms) 559 560 data.to_numpy() 561 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 562 return masks
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py It decouples the computationally expensive steps of generating masks from the cheap post-processing operation to filter these masks to enable grid search and interactively changing the post-processing.
Use this class as follows:
```python
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters.
```
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image.
If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
- crop_overlap_ratio: Sets the degree to which crops overlap.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
- point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
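Because all expensive computation happens in `initialize`, the post-processing parameters of `generate` can be swept cheaply afterwards. A minimal sketch, assuming a 2D image and a SAM predictor obtained via `util.get_sam_model` (the model type and threshold values are just example choices):

```python
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Example inputs: any 2D grayscale or RGB image works; random data is only a placeholder.
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
predictor = util.get_sam_model(model_type="vit_b")  # model type is an example choice

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # expensive: embeddings and mask predictions for the point grid

# Cheap: try different filtering thresholds without recomputing the masks.
for iou_thresh in (0.7, 0.8, 0.88):
    masks = amg.generate(pred_iou_thresh=iou_thresh)
    segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```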
367 def __init__( 368 self, 369 predictor: SamPredictor, 370 points_per_side: Optional[int] = 32, 371 points_per_batch: Optional[int] = None, 372 crop_n_layers: int = 0, 373 crop_overlap_ratio: float = 512 / 1500, 374 crop_n_points_downscale_factor: int = 1, 375 point_grids: Optional[List[np.ndarray]] = None, 376 stability_score_offset: float = 1.0, 377 ): 378 super().__init__() 379 380 if points_per_side is not None: 381 self.point_grids = amg_utils.build_all_layer_point_grids( 382 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 383 ) 384 elif point_grids is not None: 385 self.point_grids = point_grids 386 else: 387 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 388 389 self._predictor = predictor 390 self._points_per_side = points_per_side 391 392 # we set the points per batch to 16 for mps for performance reasons 393 # and otherwise keep them at the default of 64 394 if points_per_batch is None: 395 points_per_batch = 16 if str(predictor.device) == "mps" else 64 396 self._points_per_batch = points_per_batch 397 398 self._crop_n_layers = crop_n_layers 399 self._crop_overlap_ratio = crop_overlap_ratio 400 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 401 self._stability_score_offset = stability_score_offset
450 @torch.no_grad() 451 def initialize( 452 self, 453 image: np.ndarray, 454 image_embeddings: Optional[util.ImageEmbeddings] = None, 455 i: Optional[int] = None, 456 verbose: bool = False, 457 pbar_init: Optional[callable] = None, 458 pbar_update: Optional[callable] = None, 459 ) -> None: 460 """Initialize image embeddings and masks for an image. 461 462 Args: 463 image: The input image, volume or timeseries. 464 image_embeddings: Optional precomputed image embeddings. 465 See `util.precompute_image_embeddings` for details. 466 i: Index for the image data. Required if `image` has three spatial dimensions 467 or a time dimension and two spatial dimensions. 468 verbose: Whether to print computation progress. 469 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 470 Can be used together with pbar_update to handle napari progress bar in other thread. 471 To enables using this function within a threadworker. 472 pbar_update: Callback to update an external progress bar. 473 """ 474 original_size = image.shape[:2] 475 self._original_size = original_size 476 477 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 478 original_size, self._crop_n_layers, self._crop_overlap_ratio 479 ) 480 481 # We can set fixed image embeddings if we only have a single crop box (the default setting). 482 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 483 if len(crop_boxes) == 1: 484 if image_embeddings is None: 485 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 486 util.set_precomputed(self._predictor, image_embeddings, i=i) 487 precomputed_embeddings = True 488 else: 489 precomputed_embeddings = False 490 491 # we need to cast to the image representation that is compatible with SAM 492 image = util._to_image(image) 493 494 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 495 496 crop_list = [] 497 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 498 crop_data = self._process_crop( 499 image, crop_box, layer_idx, 500 precomputed_embeddings=precomputed_embeddings, 501 pbar_init=pbar_init, pbar_update=pbar_update, 502 ) 503 crop_list.append(crop_data) 504 pbar_close() 505 506 self._is_initialized = True 507 self._crop_list = crop_list 508 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
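Embeddings can also be precomputed once and passed to `initialize`, e.g. when the same image is segmented repeatedly. A short sketch based on the `util.precompute_image_embeddings` call that `initialize` uses internally (it assumes `predictor`, `image` and `amg` from the previous example):

```python
from micro_sam import util

# Precompute the embeddings once, then pass them in so initialize does not recompute them.
image_embeddings = util.precompute_image_embeddings(predictor, image)
amg.initialize(image, image_embeddings=image_embeddings, verbose=True)
masks = amg.generate()
```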
510 @torch.no_grad() 511 def generate( 512 self, 513 pred_iou_thresh: float = 0.88, 514 stability_score_thresh: float = 0.95, 515 box_nms_thresh: float = 0.7, 516 crop_nms_thresh: float = 0.7, 517 min_mask_region_area: int = 0, 518 output_mode: str = "binary_mask", 519 ) -> List[Dict[str, Any]]: 520 """Generate instance segmentation for the currently initialized image. 521 522 Args: 523 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 524 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 525 under changes to the cutoff used to binarize the model prediction. 526 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 527 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 528 min_mask_region_area: Minimal size for the predicted masks. 529 output_mode: The form masks are returned in. 530 531 Returns: 532 The instance segmentation masks. 533 """ 534 if not self.is_initialized: 535 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 536 537 data = amg_utils.MaskData() 538 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 539 crop_data = self._postprocess_batch( 540 data=deepcopy(data_), 541 crop_box=crop_box, original_size=self.original_size, 542 pred_iou_thresh=pred_iou_thresh, 543 stability_score_thresh=stability_score_thresh, 544 box_nms_thresh=box_nms_thresh 545 ) 546 data.cat(crop_data) 547 548 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 549 # Prefer masks from smaller crops 550 scores = 1 / box_area(data["crop_boxes"]) 551 scores = scores.to(data["boxes"].device) 552 keep_by_nms = batched_nms( 553 data["boxes"].float(), 554 scores, 555 torch.zeros_like(data["boxes"][:, 0]), # categories 556 iou_threshold=crop_nms_thresh, 557 ) 558 data.filter(keep_by_nms) 559 560 data.to_numpy() 561 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 562 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
- min_mask_region_area: Minimal size for the predicted masks.
- output_mode: The form masks are returned in.
Returns:
The instance segmentation masks.
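The returned annotations follow the format of the Segment Anything automatic mask generator: each entry contains the mask together with its area, bounding box, predicted IoU, stability score and crop box, as assembled in `_postprocess_masks` above. A small inspection sketch (assuming `amg` is an initialized generator):

```python
masks = amg.generate(pred_iou_thresh=0.88, stability_score_thresh=0.95, output_mode="binary_mask")
for ann in masks[:3]:
    # "segmentation" is a binary numpy mask when output_mode="binary_mask".
    print(ann["area"], ann["bbox"], ann["predicted_iou"], ann["stability_score"])
```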
592class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 593 """Generates an instance segmentation without prompts, using a point grid. 594 595 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 596 597 Args: 598 predictor: The segment anything predictor. 599 points_per_side: The number of points to be sampled along one side of the image. 600 If None, `point_grids` must provide explicit point sampling. 601 points_per_batch: The number of points run simultaneously by the model. 602 Higher numbers may be faster but use more GPU memory. 603 point_grids: A lisst over explicit grids of points used for sampling masks. 604 Normalized to [0, 1] with respect to the image coordinate system. 605 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 606 """ 607 608 # We only expose the arguments that make sense for the tiled mask generator. 609 # Anything related to crops doesn't make sense, because we re-use that functionality 610 # for tiling, so these parameters wouldn't have any effect. 611 def __init__( 612 self, 613 predictor: SamPredictor, 614 points_per_side: Optional[int] = 32, 615 points_per_batch: int = 64, 616 point_grids: Optional[List[np.ndarray]] = None, 617 stability_score_offset: float = 1.0, 618 ) -> None: 619 super().__init__( 620 predictor=predictor, 621 points_per_side=points_per_side, 622 points_per_batch=points_per_batch, 623 point_grids=point_grids, 624 stability_score_offset=stability_score_offset, 625 ) 626 627 @torch.no_grad() 628 def initialize( 629 self, 630 image: np.ndarray, 631 image_embeddings: Optional[util.ImageEmbeddings] = None, 632 i: Optional[int] = None, 633 tile_shape: Optional[Tuple[int, int]] = None, 634 halo: Optional[Tuple[int, int]] = None, 635 verbose: bool = False, 636 pbar_init: Optional[callable] = None, 637 pbar_update: Optional[callable] = None, 638 ) -> None: 639 """Initialize image embeddings and masks for an image. 640 641 Args: 642 image: The input image, volume or timeseries. 643 image_embeddings: Optional precomputed image embeddings. 644 See `util.precompute_image_embeddings` for details. 645 i: Index for the image data. Required if `image` has three spatial dimensions 646 or a time dimension and two spatial dimensions. 647 tile_shape: The tile shape for embedding prediction. 648 halo: The overlap of between tiles. 649 verbose: Whether to print computation progress. 650 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 651 Can be used together with pbar_update to handle napari progress bar in other thread. 652 To enables using this function within a threadworker. 653 pbar_update: Callback to update an external progress bar. 654 """ 655 original_size = image.shape[:2] 656 self._original_size = original_size 657 658 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 659 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, 660 ) 661 662 tiling = blocking([0, 0], original_size, tile_shape) 663 n_tiles = tiling.numberOfBlocks 664 665 # The crop box is always the full local tile. 
666 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 667 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 668 669 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 670 pbar_init(n_tiles, "Compute masks for tile") 671 672 # We need to cast to the image representation that is compatible with SAM. 673 image = util._to_image(image) 674 675 mask_data = [] 676 for tile_id in range(n_tiles): 677 # set the pre-computed embeddings for this tile 678 features = image_embeddings["features"][tile_id] 679 tile_embeddings = { 680 "features": features, 681 "input_size": features.attrs["input_size"], 682 "original_size": features.attrs["original_size"], 683 } 684 util.set_precomputed(self._predictor, tile_embeddings, i) 685 686 # compute the mask data for this tile and append it 687 this_mask_data = self._process_crop( 688 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 689 ) 690 mask_data.append(this_mask_data) 691 pbar_update(1) 692 pbar_close() 693 694 # set the initialized data 695 self._is_initialized = True 696 self._crop_list = mask_data 697 self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image.
If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
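For images that are too large for a single embedding computation, the tiled generator computes embeddings per tile; `tile_shape` and `halo` are passed to `initialize`. A sketch with example values:

```python
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator

# `predictor` as before; `large_image` stands for a large 2D image, e.g. loaded with imageio.
amg = TiledAutomaticMaskGenerator(predictor)
amg.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)
masks = amg.generate(pred_iou_thresh=0.88)
```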
611 def __init__( 612 self, 613 predictor: SamPredictor, 614 points_per_side: Optional[int] = 32, 615 points_per_batch: int = 64, 616 point_grids: Optional[List[np.ndarray]] = None, 617 stability_score_offset: float = 1.0, 618 ) -> None: 619 super().__init__( 620 predictor=predictor, 621 points_per_side=points_per_side, 622 points_per_batch=points_per_batch, 623 point_grids=point_grids, 624 stability_score_offset=stability_score_offset, 625 )
627 @torch.no_grad() 628 def initialize( 629 self, 630 image: np.ndarray, 631 image_embeddings: Optional[util.ImageEmbeddings] = None, 632 i: Optional[int] = None, 633 tile_shape: Optional[Tuple[int, int]] = None, 634 halo: Optional[Tuple[int, int]] = None, 635 verbose: bool = False, 636 pbar_init: Optional[callable] = None, 637 pbar_update: Optional[callable] = None, 638 ) -> None: 639 """Initialize image embeddings and masks for an image. 640 641 Args: 642 image: The input image, volume or timeseries. 643 image_embeddings: Optional precomputed image embeddings. 644 See `util.precompute_image_embeddings` for details. 645 i: Index for the image data. Required if `image` has three spatial dimensions 646 or a time dimension and two spatial dimensions. 647 tile_shape: The tile shape for embedding prediction. 648 halo: The overlap of between tiles. 649 verbose: Whether to print computation progress. 650 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 651 Can be used together with pbar_update to handle napari progress bar in other thread. 652 To enables using this function within a threadworker. 653 pbar_update: Callback to update an external progress bar. 654 """ 655 original_size = image.shape[:2] 656 self._original_size = original_size 657 658 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 659 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, 660 ) 661 662 tiling = blocking([0, 0], original_size, tile_shape) 663 n_tiles = tiling.numberOfBlocks 664 665 # The crop box is always the full local tile. 666 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 667 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 668 669 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 670 pbar_init(n_tiles, "Compute masks for tile") 671 672 # We need to cast to the image representation that is compatible with SAM. 673 image = util._to_image(image) 674 675 mask_data = [] 676 for tile_id in range(n_tiles): 677 # set the pre-computed embeddings for this tile 678 features = image_embeddings["features"][tile_id] 679 tile_embeddings = { 680 "features": features, 681 "input_size": features.attrs["input_size"], 682 "original_size": features.attrs["original_size"], 683 } 684 util.set_precomputed(self._predictor, tile_embeddings, i) 685 686 # compute the mask data for this tile and append it 687 this_mask_data = self._process_crop( 688 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 689 ) 690 mask_data.append(this_mask_data) 691 pbar_update(1) 692 pbar_close() 693 694 # set the initialized data 695 self._is_initialized = True 696 self._crop_list = mask_data 697 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
705class DecoderAdapter(torch.nn.Module): 706 """Adapter to contain the UNETR decoder in a single module. 707 708 To apply the decoder on top of pre-computed embeddings for the segmentation functionality. 709 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 710 """ 711 def __init__(self, unetr): 712 super().__init__() 713 714 self.base = unetr.base 715 self.out_conv = unetr.out_conv 716 self.deconv_out = unetr.deconv_out 717 self.decoder_head = unetr.decoder_head 718 self.final_activation = unetr.final_activation 719 self.postprocess_masks = unetr.postprocess_masks 720 721 self.decoder = unetr.decoder 722 self.deconv1 = unetr.deconv1 723 self.deconv2 = unetr.deconv2 724 self.deconv3 = unetr.deconv3 725 self.deconv4 = unetr.deconv4 726 727 def forward(self, input_, input_shape, original_shape): 728 z12 = input_ 729 730 z9 = self.deconv1(z12) 731 z6 = self.deconv2(z9) 732 z3 = self.deconv3(z6) 733 z0 = self.deconv4(z3) 734 735 updated_from_encoder = [z9, z6, z3] 736 737 x = self.base(z12) 738 x = self.decoder(x, encoder_inputs=updated_from_encoder) 739 x = self.deconv_out(x) 740 741 x = torch.cat([x, z0], dim=1) 742 x = self.decoder_head(x) 743 744 x = self.out_conv(x) 745 if self.final_activation is not None: 746 x = self.final_activation(x) 747 748 x = self.postprocess_masks(x, input_shape, original_shape) 749 return x
Adapter to contain the UNETR decoder in a single module.
To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
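The adapter is applied to precomputed SAM image embeddings. The call below mirrors how `InstanceSegmentationWithDecoder.initialize` uses it; `predictor`, `decoder` and `image_embeddings` are assumed to be set up already (e.g. via `get_predictor_and_decoder` and `util.precompute_image_embeddings`):

```python
import torch

from micro_sam import util

with torch.no_grad():
    # Set the precomputed embeddings on the predictor, then decode them into the
    # foreground / center-distance / boundary-distance channels.
    util.set_precomputed(predictor, image_embeddings)
    output = decoder(
        predictor.features, tuple(predictor.input_size), tuple(predictor.original_size)
    )  # shape: (1, 3, H, W)
```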
711 def __init__(self, unetr): 712 super().__init__() 713 714 self.base = unetr.base 715 self.out_conv = unetr.out_conv 716 self.deconv_out = unetr.deconv_out 717 self.decoder_head = unetr.decoder_head 718 self.final_activation = unetr.final_activation 719 self.postprocess_masks = unetr.postprocess_masks 720 721 self.decoder = unetr.decoder 722 self.deconv1 = unetr.deconv1 723 self.deconv2 = unetr.deconv2 724 self.deconv3 = unetr.deconv3 725 self.deconv4 = unetr.deconv4
Initialize internal Module state, shared by both nn.Module and ScriptModule.
727 def forward(self, input_, input_shape, original_shape): 728 z12 = input_ 729 730 z9 = self.deconv1(z12) 731 z6 = self.deconv2(z9) 732 z3 = self.deconv3(z6) 733 z0 = self.deconv4(z3) 734 735 updated_from_encoder = [z9, z6, z3] 736 737 x = self.base(z12) 738 x = self.decoder(x, encoder_inputs=updated_from_encoder) 739 x = self.deconv_out(x) 740 741 x = torch.cat([x, z0], dim=1) 742 x = self.decoder_head(x) 743 744 x = self.out_conv(x) 745 if self.final_activation is not None: 746 x = self.final_activation(x) 747 748 x = self.postprocess_masks(x, input_shape, original_shape) 749 return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
752def get_unetr( 753 image_encoder: torch.nn.Module, 754 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 755 device: Optional[Union[str, torch.device]] = None, 756 out_channels: int = 3, 757 flexible_load_checkpoint: bool = False, 758) -> torch.nn.Module: 759 """Get UNETR model for automatic instance segmentation. 760 761 Args: 762 image_encoder: The image encoder of the SAM model. 763 This is used as encoder by the UNETR too. 764 decoder_state: Optional decoder state to initialize the weights of the UNETR decoder. 765 device: The device. 766 out_channels: The number of output channels. 767 flexible_load_checkpoint: Whether to allow reinitialization of parameters 768 which could not be found in the provided decoder state. 769 770 Returns: 771 The UNETR model. 772 """ 773 device = util.get_device(device) 774 775 unetr = UNETR( 776 backbone="sam", 777 encoder=image_encoder, 778 out_channels=out_channels, 779 use_sam_stats=True, 780 final_activation="Sigmoid", 781 use_skip_connection=False, 782 resize_input=True, 783 ) 784 if decoder_state is not None: 785 unetr_state_dict = unetr.state_dict() 786 for k, v in unetr_state_dict.items(): 787 if not k.startswith("encoder"): 788 if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found. 789 if k in decoder_state: # First check whether the key is available in the provided decoder state. 790 unetr_state_dict[k] = decoder_state[k] 791 else: # Otherwise, allow it to initialize it. 792 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 793 unetr_state_dict[k] = v 794 795 else: # Whether be strict on finding the parameter in the decoder state. 796 if k not in decoder_state: 797 raise RuntimeError(f"The parameters for '{k}' could not be found.") 798 unetr_state_dict[k] = decoder_state[k] 799 800 unetr.load_state_dict(unetr_state_dict) 801 802 unetr.to(device) 803 return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device.
- out_channels: The number of output channels.
- flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state.
Returns:
The UNETR model.
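A minimal sketch of building the UNETR from a SAM image encoder; without a `decoder_state` the decoder weights remain randomly initialized (the model type is an example value):

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

predictor = util.get_sam_model(model_type="vit_b")
# Reuse the SAM image encoder as the UNETR encoder; out_channels=3 matches the
# foreground / center-distance / boundary-distance outputs used in this module.
unetr = get_unetr(image_encoder=predictor.model.image_encoder, out_channels=3)
```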
806def get_decoder( 807 image_encoder: torch.nn.Module, 808 decoder_state: OrderedDict[str, torch.Tensor], 809 device: Optional[Union[str, torch.device]] = None, 810) -> DecoderAdapter: 811 """Get decoder to predict outputs for automatic instance segmentation 812 813 Args: 814 image_encoder: The image encoder of the SAM model. 815 decoder_state: State to initialize the weights of the UNETR decoder. 816 device: The device. 817 818 Returns: 819 The decoder for instance segmentation. 820 """ 821 unetr = get_unetr(image_encoder, decoder_state, device) 822 return DecoderAdapter(unetr)
Get decoder to predict outputs for automatic instance segmentation
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device.
Returns:
The decoder for instance segmentation.
825def get_predictor_and_decoder( 826 model_type: str, 827 checkpoint_path: Union[str, os.PathLike], 828 device: Optional[Union[str, torch.device]] = None, 829 peft_kwargs: Optional[Dict] = None, 830) -> Tuple[SamPredictor, DecoderAdapter]: 831 """Load the SAM model (predictor) and instance segmentation decoder. 832 833 This requires a checkpoint that contains the state for both predictor 834 and decoder. 835 836 Args: 837 model_type: The type of the image encoder used in the SAM model. 838 checkpoint_path: Path to the checkpoint from which to load the data. 839 device: The device. 840 peft_kwargs: Keyword arguments for the PEFT wrapper class. 841 842 Returns: 843 The SAM predictor. 844 The decoder for instance segmentation. 845 """ 846 device = util.get_device(device) 847 predictor, state = util.get_sam_model( 848 model_type=model_type, 849 checkpoint_path=checkpoint_path, 850 device=device, 851 return_state=True, 852 peft_kwargs=peft_kwargs, 853 ) 854 if "decoder_state" not in state: 855 raise ValueError( 856 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 857 ) 858 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 859 return predictor, decoder
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:
The SAM predictor.
The decoder for instance segmentation.
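A usage sketch; the checkpoint path is a placeholder, and the checkpoint must contain a `decoder_state` in addition to the SAM weights, otherwise a ValueError is raised:

```python
from micro_sam.instance_segmentation import get_predictor_and_decoder

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b",                        # type of the image encoder in the checkpoint
    checkpoint_path="/path/to/checkpoint.pt",  # placeholder path
)
```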
911class InstanceSegmentationWithDecoder: 912 """Generates an instance segmentation without prompts, using a decoder. 913 914 Implements the same interface as `AutomaticMaskGenerator`. 915 916 Use this class as follows: 917 ```python 918 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 919 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 920 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 921 ``` 922 923 Args: 924 predictor: The segment anything predictor. 925 decoder: The decoder to predict intermediate representations 926 for instance segmentation. 927 """ 928 def __init__( 929 self, 930 predictor: SamPredictor, 931 decoder: torch.nn.Module, 932 ) -> None: 933 self._predictor = predictor 934 self._decoder = decoder 935 936 # The decoder outputs. 937 self._foreground = None 938 self._center_distances = None 939 self._boundary_distances = None 940 941 self._is_initialized = False 942 943 @property 944 def is_initialized(self): 945 """Whether the mask generator has already been initialized. 946 """ 947 return self._is_initialized 948 949 @torch.no_grad() 950 def initialize( 951 self, 952 image: np.ndarray, 953 image_embeddings: Optional[util.ImageEmbeddings] = None, 954 i: Optional[int] = None, 955 verbose: bool = False, 956 pbar_init: Optional[callable] = None, 957 pbar_update: Optional[callable] = None, 958 ) -> None: 959 """Initialize image embeddings and decoder predictions for an image. 960 961 Args: 962 image: The input image, volume or timeseries. 963 image_embeddings: Optional precomputed image embeddings. 964 See `util.precompute_image_embeddings` for details. 965 i: Index for the image data. Required if `image` has three spatial dimensions 966 or a time dimension and two spatial dimensions. 967 verbose: Whether to be verbose. 968 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 969 Can be used together with pbar_update to handle napari progress bar in other thread. 970 To enables using this function within a threadworker. 971 pbar_update: Callback to update an external progress bar. 972 """ 973 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 974 pbar_init(1, "Initialize instance segmentation with decoder") 975 976 if image_embeddings is None: 977 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 978 979 # Get the image embeddings from the predictor. 980 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 981 embeddings = self._predictor.features 982 input_shape = tuple(self._predictor.input_size) 983 original_shape = tuple(self._predictor.original_size) 984 985 # Run prediction with the UNETR decoder. 986 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 987 assert output.shape[0] == 3, f"{output.shape}" 988 pbar_update(1) 989 pbar_close() 990 991 # Set the state. 
992 self._foreground = output[0] 993 self._center_distances = output[1] 994 self._boundary_distances = output[2] 995 self._is_initialized = True 996 997 def _to_masks(self, segmentation, output_mode): 998 if output_mode != "binary_mask": 999 raise NotImplementedError 1000 1001 props = regionprops(segmentation) 1002 ndim = segmentation.ndim 1003 assert ndim in (2, 3) 1004 1005 shape = segmentation.shape 1006 if ndim == 2: 1007 crop_box = [0, shape[1], 0, shape[0]] 1008 else: 1009 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1010 1011 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1012 def to_bbox_2d(bbox): 1013 y0, x0 = bbox[0], bbox[1] 1014 w = bbox[3] - x0 1015 h = bbox[2] - y0 1016 return [x0, w, y0, h] 1017 1018 def to_bbox_3d(bbox): 1019 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1020 w = bbox[5] - x0 1021 h = bbox[4] - y0 1022 d = bbox[3] - y0 1023 return [x0, w, y0, h, z0, d] 1024 1025 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1026 masks = [ 1027 { 1028 "segmentation": segmentation == prop.label, 1029 "area": prop.area, 1030 "bbox": to_bbox(prop.bbox), 1031 "crop_box": crop_box, 1032 "seg_id": prop.label, 1033 } for prop in props 1034 ] 1035 return masks 1036 1037 def generate( 1038 self, 1039 center_distance_threshold: float = 0.5, 1040 boundary_distance_threshold: float = 0.5, 1041 foreground_threshold: float = 0.5, 1042 foreground_smoothing: float = 1.0, 1043 distance_smoothing: float = 1.6, 1044 min_size: int = 0, 1045 output_mode: Optional[str] = "binary_mask", 1046 tile_shape: Optional[Tuple[int, int]] = None, 1047 halo: Optional[Tuple[int, int]] = None, 1048 n_threads: Optional[int] = None, 1049 ) -> List[Dict[str, Any]]: 1050 """Generate instance segmentation for the currently initialized image. 1051 1052 Args: 1053 center_distance_threshold: Center distance predictions below this value will be 1054 used to find seeds (intersected with thresholded boundary distance predictions). 1055 boundary_distance_threshold: Boundary distance predictions below this value will be 1056 used to find seeds (intersected with thresholded center distance predictions). 1057 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1058 checkerboard artifacts in the prediction. 1059 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1060 distance_smoothing: Sigma value for smoothing the distance predictions. 1061 min_size: Minimal object size in the segmentation result. 1062 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1063 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1064 This parameter is independent from the tile shape for computing the embeddings. 1065 If not given then post-processing will not be parallelized. 1066 halo: Halo for parallel post-processing. See also `tile_shape`. 1067 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1068 1069 Returns: 1070 The instance segmentation masks. 1071 """ 1072 if not self.is_initialized: 1073 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1074 1075 if foreground_smoothing > 0: 1076 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1077 else: 1078 foreground = self._foreground 1079 1080 if tile_shape is None: 1081 segmentation = watershed_from_center_and_boundary_distances( 1082 center_distances=self._center_distances, 1083 boundary_distances=self._boundary_distances, 1084 foreground_map=foreground, 1085 center_distance_threshold=center_distance_threshold, 1086 boundary_distance_threshold=boundary_distance_threshold, 1087 foreground_threshold=foreground_threshold, 1088 distance_smoothing=distance_smoothing, 1089 min_size=min_size, 1090 ) 1091 else: 1092 if halo is None: 1093 raise ValueError("You must pass a value for halo if tile_shape is given.") 1094 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1095 center_distances=self._center_distances, 1096 boundary_distances=self._boundary_distances, 1097 foreground_map=foreground, 1098 center_distance_threshold=center_distance_threshold, 1099 boundary_distance_threshold=boundary_distance_threshold, 1100 foreground_threshold=foreground_threshold, 1101 distance_smoothing=distance_smoothing, 1102 min_size=min_size, 1103 tile_shape=tile_shape, 1104 halo=halo, 1105 n_threads=n_threads, 1106 verbose=False, 1107 ) 1108 1109 if output_mode is not None: 1110 segmentation = self._to_masks(segmentation, output_mode) 1111 return segmentation 1112 1113 def get_state(self) -> Dict[str, Any]: 1114 """Get the initialized state of the instance segmenter. 1115 1116 Returns: 1117 Instance segmentation state. 1118 """ 1119 if not self.is_initialized: 1120 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1121 1122 return { 1123 "foreground": self._foreground, 1124 "center_distances": self._center_distances, 1125 "boundary_distances": self._boundary_distances, 1126 } 1127 1128 def set_state(self, state: Dict[str, Any]) -> None: 1129 """Set the state of the instance segmenter. 1130 1131 Args: 1132 state: The instance segmentation state 1133 """ 1134 self._foreground = state["foreground"] 1135 self._center_distances = state["center_distances"] 1136 self._boundary_distances = state["boundary_distances"] 1137 self._is_initialized = True 1138 1139 def clear_state(self): 1140 """Clear the state of the instance segmenter. 1141 """ 1142 self._foreground = None 1143 self._center_distances = None 1144 self._boundary_distances = None 1145 self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as `AutomaticMaskGenerator`.
Use this class as follows:
```python
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
```
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
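Extending the usage snippet above: the generated masks can either be converted to a label image with `mask_data_to_segmentation`, or `generate` can return the instance segmentation directly by passing `output_mode=None`. A sketch assuming `predictor`, `decoder` and `image` from the earlier examples:

```python
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder, mask_data_to_segmentation

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)

# Option 1: mask dictionaries, converted to a label image afterwards.
masks = segmenter.generate(center_distance_threshold=0.5, boundary_distance_threshold=0.5)
instances = mask_data_to_segmentation(masks, with_background=True)

# Option 2: return the instance segmentation directly.
instances = segmenter.generate(output_mode=None, min_size=50)
```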
928 def __init__( 929 self, 930 predictor: SamPredictor, 931 decoder: torch.nn.Module, 932 ) -> None: 933 self._predictor = predictor 934 self._decoder = decoder 935 936 # The decoder outputs. 937 self._foreground = None 938 self._center_distances = None 939 self._boundary_distances = None 940 941 self._is_initialized = False
943 @property 944 def is_initialized(self): 945 """Whether the mask generator has already been initialized. 946 """ 947 return self._is_initialized
Whether the mask generator has already been initialized.
949 @torch.no_grad() 950 def initialize( 951 self, 952 image: np.ndarray, 953 image_embeddings: Optional[util.ImageEmbeddings] = None, 954 i: Optional[int] = None, 955 verbose: bool = False, 956 pbar_init: Optional[callable] = None, 957 pbar_update: Optional[callable] = None, 958 ) -> None: 959 """Initialize image embeddings and decoder predictions for an image. 960 961 Args: 962 image: The input image, volume or timeseries. 963 image_embeddings: Optional precomputed image embeddings. 964 See `util.precompute_image_embeddings` for details. 965 i: Index for the image data. Required if `image` has three spatial dimensions 966 or a time dimension and two spatial dimensions. 967 verbose: Whether to be verbose. 968 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 969 Can be used together with pbar_update to handle napari progress bar in other thread. 970 To enables using this function within a threadworker. 971 pbar_update: Callback to update an external progress bar. 972 """ 973 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 974 pbar_init(1, "Initialize instance segmentation with decoder") 975 976 if image_embeddings is None: 977 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 978 979 # Get the image embeddings from the predictor. 980 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 981 embeddings = self._predictor.features 982 input_shape = tuple(self._predictor.input_size) 983 original_shape = tuple(self._predictor.original_size) 984 985 # Run prediction with the UNETR decoder. 986 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 987 assert output.shape[0] == 3, f"{output.shape}" 988 pbar_update(1) 989 pbar_close() 990 991 # Set the state. 992 self._foreground = output[0] 993 self._center_distances = output[1] 994 self._boundary_distances = output[2] 995 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings.
See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
1037 def generate( 1038 self, 1039 center_distance_threshold: float = 0.5, 1040 boundary_distance_threshold: float = 0.5, 1041 foreground_threshold: float = 0.5, 1042 foreground_smoothing: float = 1.0, 1043 distance_smoothing: float = 1.6, 1044 min_size: int = 0, 1045 output_mode: Optional[str] = "binary_mask", 1046 tile_shape: Optional[Tuple[int, int]] = None, 1047 halo: Optional[Tuple[int, int]] = None, 1048 n_threads: Optional[int] = None, 1049 ) -> List[Dict[str, Any]]: 1050 """Generate instance segmentation for the currently initialized image. 1051 1052 Args: 1053 center_distance_threshold: Center distance predictions below this value will be 1054 used to find seeds (intersected with thresholded boundary distance predictions). 1055 boundary_distance_threshold: Boundary distance predictions below this value will be 1056 used to find seeds (intersected with thresholded center distance predictions). 1057 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1058 checkerboard artifacts in the prediction. 1059 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1060 distance_smoothing: Sigma value for smoothing the distance predictions. 1061 min_size: Minimal object size in the segmentation result. 1062 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1063 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1064 This parameter is independent from the tile shape for computing the embeddings. 1065 If not given then post-processing will not be parallelized. 1066 halo: Halo for parallel post-processing. See also `tile_shape`. 1067 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1068 1069 Returns: 1070 The instance segmentation masks. 1071 """ 1072 if not self.is_initialized: 1073 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 1074 1075 if foreground_smoothing > 0: 1076 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1077 else: 1078 foreground = self._foreground 1079 1080 if tile_shape is None: 1081 segmentation = watershed_from_center_and_boundary_distances( 1082 center_distances=self._center_distances, 1083 boundary_distances=self._boundary_distances, 1084 foreground_map=foreground, 1085 center_distance_threshold=center_distance_threshold, 1086 boundary_distance_threshold=boundary_distance_threshold, 1087 foreground_threshold=foreground_threshold, 1088 distance_smoothing=distance_smoothing, 1089 min_size=min_size, 1090 ) 1091 else: 1092 if halo is None: 1093 raise ValueError("You must pass a value for halo if tile_shape is given.") 1094 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1095 center_distances=self._center_distances, 1096 boundary_distances=self._boundary_distances, 1097 foreground_map=foreground, 1098 center_distance_threshold=center_distance_threshold, 1099 boundary_distance_threshold=boundary_distance_threshold, 1100 foreground_threshold=foreground_threshold, 1101 distance_smoothing=distance_smoothing, 1102 min_size=min_size, 1103 tile_shape=tile_shape, 1104 halo=halo, 1105 n_threads=n_threads, 1106 verbose=False, 1107 ) 1108 1109 if output_mode is not None: 1110 segmentation = self._to_masks(segmentation, output_mode) 1111 return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask.
- distance_smoothing: Sigma value for smoothing the distance predictions.
- min_size: Minimal object size in the segmentation result.
- output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
- tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
- halo: Halo for parallel post-processing. See also `tile_shape`.
- n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
Returns:
The instance segmentation masks.
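The watershed post-processing can be parallelized over tiles; `halo` is required whenever `tile_shape` is given, and this tiling is independent of the embedding tiling. A sketch with example values, assuming an initialized `segmenter`:

```python
segmentation = segmenter.generate(
    output_mode=None,          # return the label image directly
    tile_shape=(1024, 1024),   # tiling of the post-processing, not of the embeddings
    halo=(128, 128),           # required together with tile_shape
    n_threads=8,
)
```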
1113 def get_state(self) -> Dict[str, Any]: 1114 """Get the initialized state of the instance segmenter. 1115 1116 Returns: 1117 Instance segmentation state. 1118 """ 1119 if not self.is_initialized: 1120 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1121 1122 return { 1123 "foreground": self._foreground, 1124 "center_distances": self._center_distances, 1125 "boundary_distances": self._boundary_distances, 1126 }
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
1128 def set_state(self, state: Dict[str, Any]) -> None: 1129 """Set the state of the instance segmenter. 1130 1131 Args: 1132 state: The instance segmentation state 1133 """ 1134 self._foreground = state["foreground"] 1135 self._center_distances = state["center_distances"] 1136 self._boundary_distances = state["boundary_distances"] 1137 self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state.
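Since the decoder predictions in the state are plain numpy arrays, they can be cached to disk, e.g. with numpy, and restored into a segmenter later; the file name is a placeholder:

```python
import numpy as np

state = segmenter.get_state()
np.savez("decoder_predictions.npz", **state)

# Restore later and generate without recomputing embeddings or decoder predictions.
loaded = np.load("decoder_predictions.npz")
segmenter.set_state({key: loaded[key] for key in ("foreground", "center_distances", "boundary_distances")})
masks = segmenter.generate()
```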
1148class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1149 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1150 """ 1151 1152 @torch.no_grad() 1153 def initialize( 1154 self, 1155 image: np.ndarray, 1156 image_embeddings: Optional[util.ImageEmbeddings] = None, 1157 i: Optional[int] = None, 1158 tile_shape: Optional[Tuple[int, int]] = None, 1159 halo: Optional[Tuple[int, int]] = None, 1160 verbose: bool = False, 1161 pbar_init: Optional[callable] = None, 1162 pbar_update: Optional[callable] = None, 1163 ) -> None: 1164 """Initialize image embeddings and decoder predictions for an image. 1165 1166 Args: 1167 image: The input image, volume or timeseries. 1168 image_embeddings: Optional precomputed image embeddings. 1169 See `util.precompute_image_embeddings` for details. 1170 i: Index for the image data. Required if `image` has three spatial dimensions 1171 or a time dimension and two spatial dimensions. 1172 tile_shape: Shape of the tiles for precomputing image embeddings. 1173 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1174 verbose: Dummy input to be compatible with other function signatures. 1175 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1176 Can be used together with pbar_update to handle napari progress bar in other thread. 1177 To enables using this function within a threadworker. 1178 pbar_update: Callback to update an external progress bar. 1179 """ 1180 original_size = image.shape[:2] 1181 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1182 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, 1183 ) 1184 tiling = blocking([0, 0], original_size, tile_shape) 1185 1186 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1187 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1188 1189 foreground = np.zeros(original_size, dtype="float32") 1190 center_distances = np.zeros(original_size, dtype="float32") 1191 boundary_distances = np.zeros(original_size, dtype="float32") 1192 1193 for tile_id in range(tiling.numberOfBlocks): 1194 1195 # Get the image embeddings from the predictor for this tile. 1196 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1197 embeddings = self._predictor.features 1198 input_shape = tuple(self._predictor.input_size) 1199 original_shape = tuple(self._predictor.original_size) 1200 1201 # Predict with the UNETR decoder for this tile. 1202 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1203 assert output.shape[0] == 3, f"{output.shape}" 1204 1205 # Set the predictions in the output for this tile. 1206 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1207 local_bb = tuple( 1208 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1209 ) 1210 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1211 1212 foreground[inner_bb] = output[0][local_bb] 1213 center_distances[inner_bb] = output[1][local_bb] 1214 boundary_distances[inner_bb] = output[2][local_bb] 1215 pbar_update(1) 1216 1217 pbar_close() 1218 1219 # Set the state. 1220 self._foreground = foreground 1221 self._center_distances = center_distances 1222 self._boundary_distances = boundary_distances 1223 self._is_initialized = True
Same as InstanceSegmentationWithDecoder but for tiled image embeddings.
@torch.no_grad()
def initialize(
    self,
    image: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = False,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> None:
    """Initialize image embeddings and decoder predictions for an image.

    Args:
        image: The input image, volume or timeseries.
        image_embeddings: Optional precomputed image embeddings.
            See `util.precompute_image_embeddings` for details.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_shape: Shape of the tiles for precomputing image embeddings.
        halo: Overlap of the tiles for tiled precomputation of image embeddings.
        verbose: Dummy input to be compatible with other function signatures.
        pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description.
            Can be used together with pbar_update to handle the napari progress bar in another thread.
            This enables using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.
    """
    original_size = image.shape[:2]
    image_embeddings, tile_shape, halo = _process_tiled_embeddings(
        self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
    )
    tiling = blocking([0, 0], original_size, tile_shape)

    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
    pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")

    foreground = np.zeros(original_size, dtype="float32")
    center_distances = np.zeros(original_size, dtype="float32")
    boundary_distances = np.zeros(original_size, dtype="float32")

    for tile_id in range(tiling.numberOfBlocks):

        # Get the image embeddings from the predictor for this tile.
        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
        embeddings = self._predictor.features
        input_shape = tuple(self._predictor.input_size)
        original_shape = tuple(self._predictor.original_size)

        # Predict with the UNETR decoder for this tile.
        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
        assert output.shape[0] == 3, f"{output.shape}"

        # Set the predictions in the output for this tile.
        block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
        local_bb = tuple(
            slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
        )
        inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))

        foreground[inner_bb] = output[0][local_bb]
        center_distances[inner_bb] = output[1][local_bb]
        boundary_distances[inner_bb] = output[2][local_bb]
        pbar_update(1)

    pbar_close()

    # Set the state.
    self._foreground = foreground
    self._center_distances = center_distances
    self._boundary_distances = boundary_distances
    self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: Shape of the tiles for precomputing image embeddings.
- halo: Overlap of the tiles for tiled precomputation of image embeddings.
- verbose: Dummy input to be compatible with other function signatures.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
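A minimal usage sketch, not part of the source: it assumes that get_predictor_and_decoder from this module returns a matching SAM predictor and decoder for a model that ships decoder weights (the "vit_b_lm" model name, the placeholder image, and the tile sizes are illustrative assumptions). The masks produced by generate are converted into an instance segmentation with mask_data_to_segmentation.

import numpy as np

from micro_sam.instance_segmentation import (
    TiledInstanceSegmentationWithDecoder, get_predictor_and_decoder, mask_data_to_segmentation,
)

# Placeholder input; in practice this would be a large 2D microscopy image.
image = np.random.rand(2048, 2048).astype("float32")

# Assumption: a model with decoder weights, e.g. "vit_b_lm", can be downloaded by micro_sam.
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path=None)

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
# tile_shape and halo are illustrative values; embeddings and decoder predictions
# are computed per tile and stitched into full-size maps.
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)

masks = segmenter.generate()  # mask data in the same format as the automatic mask generators
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)

Compared to the non-tiled class, only initialize differs: the decoder runs once per tile, and the halo provides overlap so that only the inner block of each tile is written to the stitched prediction maps.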
def get_amg(
    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
    """Get the automatic mask generator class.

    Args:
        predictor: The segment anything predictor.
        is_tiled: Whether tiled embeddings are used.
        decoder: Decoder to predict instance segmentation.
        kwargs: The keyword arguments for the amg class.

    Returns:
        The automatic mask generator.
    """
    if decoder is None:
        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
        segmenter = segmenter_class(predictor, **kwargs)
    else:
        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
        segmenter = segmenter_class(predictor, decoder, **kwargs)

    return segmenter
Get the automatic mask generator class.
Arguments:
- predictor: The segment anything predictor.
- is_tiled: Whether tiled embeddings are used.
- decoder: Decoder to predict instance segmentation.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator.
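A hedged sketch of how this factory is typically used; predictor, decoder, and image are assumed to be set up as in the example above, and the tile sizes are illustrative.

from micro_sam.instance_segmentation import get_amg

# With a decoder and tiled embeddings this returns a TiledInstanceSegmentationWithDecoder;
# with decoder=None it would return a (Tiled)AutomaticMaskGenerator instead.
amg = get_amg(predictor, is_tiled=True, decoder=decoder)
amg.initialize(image, tile_shape=(1024, 1024), halo=(256, 256))
masks = amg.generate()

Passing the decoder switches from grid-prompt based mask generation to the decoder-based instance segmentation, while is_tiled selects between the tiled and non-tiled variants of either approach.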