micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
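For orientation, the following is a minimal sketch of the point-grid based workflow implemented in this module. It assumes a predictor obtained from `micro_sam.util.get_sam_model`; the model type, the placeholder image, and the threshold value are example choices and not fixed by the API.

```python
import numpy as np

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Get a Segment Anything predictor (the model type here is an example choice).
predictor = get_sam_model(model_type="vit_b")

# The image is expected as a 2D numpy array; this random array is only a stand-in
# for a real microscopy image.
image = np.random.rand(512, 512).astype("float32")

# 'initialize' computes the embeddings and candidate masks (the expensive step, done once).
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)

# 'generate' applies the cheap post-processing and can be re-run with other thresholds.
masks = amg.generate(pred_iou_thresh=0.88)

# Convert the mask records into a single instance segmentation (a label image).
segmentation = mask_data_to_segmentation(masks, with_background=True)
```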
1""" 2Automated instance segmentation functionality. 3The classes implemented here extend the automatic instance segmentation from Segment Anything: 4https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 5""" 6 7import os 8from abc import ABC 9from copy import deepcopy 10from collections import OrderedDict 11from typing import Any, Dict, List, Optional, Tuple, Union 12 13import vigra 14import numpy as np 15import torch 16 17from skimage.measure import label, regionprops 18from skimage.segmentation import relabel_sequential 19from torchvision.ops.boxes import batched_nms, box_area 20 21from torch_em.model import UNETR 22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 23 24from nifty.tools import blocking 25 26import segment_anything.utils.amg as amg_utils 27from segment_anything.predictor import SamPredictor 28 29from . import util 30from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 31 32# 33# Utility Functionality 34# 35 36 37class _FakeInput: 38 def __init__(self, shape): 39 self.shape = shape 40 41 def __getitem__(self, index): 42 block_shape = tuple(ind.stop - ind.start for ind in index) 43 return np.zeros(block_shape, dtype="float32") 44 45 46def mask_data_to_segmentation( 47 masks: List[Dict[str, Any]], 48 with_background: bool, 49 min_object_size: int = 0, 50 max_object_size: Optional[int] = None, 51 label_masks: bool = True, 52) -> np.ndarray: 53 """Convert the output of the automatic mask generation to an instance segmentation. 54 55 Args: 56 masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. 57 Only supports output_mode=binary_mask. 58 with_background: Whether the segmentation has background. If yes this function assures that the largest 59 object in the output will be mapped to zero (the background value). 60 min_object_size: The minimal size of an object in pixels. 61 max_object_size: The maximal size of an object in pixels. 62 label_masks: Whether to apply connected components to the result before removing small objects. 63 64 Returns: 65 The instance segmentation. 66 """ 67 68 masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 69 # we could also get the shape from the crop box 70 shape = next(iter(masks))["segmentation"].shape 71 segmentation = np.zeros(shape, dtype="uint32") 72 73 def require_numpy(mask): 74 return mask.cpu().numpy() if torch.is_tensor(mask) else mask 75 76 seg_id = 1 77 for mask in masks: 78 if mask["area"] < min_object_size: 79 continue 80 if max_object_size is not None and mask["area"] > max_object_size: 81 continue 82 83 this_seg_id = mask.get("seg_id", seg_id) 84 segmentation[require_numpy(mask["segmentation"])] = this_seg_id 85 seg_id = this_seg_id + 1 86 87 if label_masks: 88 segmentation = label(segmentation).astype(segmentation.dtype) 89 90 seg_ids, sizes = np.unique(segmentation, return_counts=True) 91 92 # In some cases objects may be smaller than peviously calculated, 93 # since they are covered by other objects. We ensure these also get 94 # filtered out here. 95 filter_ids = seg_ids[sizes < min_object_size] 96 97 # If we run segmentation with background we also map the largest segment 98 # (the most likely background object) to zero. This is often zero already, 99 # but it does not hurt to reset that to zero either. 
100 if with_background: 101 bg_id = seg_ids[np.argmax(sizes)] 102 filter_ids = np.concatenate([filter_ids, [bg_id]]) 103 104 segmentation[np.isin(segmentation, filter_ids)] = 0 105 segmentation = relabel_sequential(segmentation)[0] 106 107 return segmentation 108 109 110# 111# Classes for automatic instance segmentation 112# 113 114 115class AMGBase(ABC): 116 """Base class for the automatic mask generators. 117 """ 118 def __init__(self): 119 # the state that has to be computed by the 'initialize' method of the child classes 120 self._is_initialized = False 121 self._crop_list = None 122 self._crop_boxes = None 123 self._original_size = None 124 125 @property 126 def is_initialized(self): 127 """Whether the mask generator has already been initialized. 128 """ 129 return self._is_initialized 130 131 @property 132 def crop_list(self): 133 """The list of mask data after initialization. 134 """ 135 return self._crop_list 136 137 @property 138 def crop_boxes(self): 139 """The list of crop boxes. 140 """ 141 return self._crop_boxes 142 143 @property 144 def original_size(self): 145 """The original image size. 146 """ 147 return self._original_size 148 149 def _postprocess_batch( 150 self, 151 data, 152 crop_box, 153 original_size, 154 pred_iou_thresh, 155 stability_score_thresh, 156 box_nms_thresh, 157 ): 158 orig_h, orig_w = original_size 159 160 # filter by predicted IoU 161 if pred_iou_thresh > 0.0: 162 keep_mask = data["iou_preds"] > pred_iou_thresh 163 data.filter(keep_mask) 164 165 # filter by stability score 166 if stability_score_thresh > 0.0: 167 keep_mask = data["stability_score"] >= stability_score_thresh 168 data.filter(keep_mask) 169 170 # filter boxes that touch crop boundaries 171 keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 172 if not torch.all(keep_mask): 173 data.filter(keep_mask) 174 175 # remove duplicates within this crop. 
176 keep_by_nms = batched_nms( 177 data["boxes"].float(), 178 data["iou_preds"], 179 torch.zeros_like(data["boxes"][:, 0]), # categories 180 iou_threshold=box_nms_thresh, 181 ) 182 data.filter(keep_by_nms) 183 184 # return to the original image frame 185 data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box) 186 data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 187 # the data from embedding based segmentation doesn't have the points 188 # so we skip if the corresponding key can't be found 189 try: 190 data["points"] = amg_utils.uncrop_points(data["points"], crop_box) 191 except KeyError: 192 pass 193 194 return data 195 196 def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): 197 198 if len(mask_data["rles"]) == 0: 199 return mask_data 200 201 # filter small disconnected regions and holes 202 new_masks = [] 203 scores = [] 204 for rle in mask_data["rles"]: 205 mask = amg_utils.rle_to_mask(rle) 206 207 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes") 208 unchanged = not changed 209 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands") 210 unchanged = unchanged and not changed 211 212 new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0)) 213 # give score=0 to changed masks and score=1 to unchanged masks 214 # so NMS will prefer ones that didn't need postprocessing 215 scores.append(float(unchanged)) 216 217 # recalculate boxes and remove any new duplicates 218 masks = torch.cat(new_masks, dim=0) 219 boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. 220 keep_by_nms = batched_nms( 221 boxes.float(), 222 torch.as_tensor(scores, dtype=torch.float), 223 torch.zeros_like(boxes[:, 0]), # categories 224 iou_threshold=nms_thresh, 225 ) 226 227 # only recalculate RLEs for masks that have changed 228 for i_mask in keep_by_nms: 229 if scores[i_mask] == 0.0: 230 mask_torch = masks[i_mask].unsqueeze(0) 231 # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] 232 mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 233 mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 234 mask_data.filter(keep_by_nms) 235 236 return mask_data 237 238 def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode): 239 # filter small disconnected regions and holes in masks 240 if min_mask_region_area > 0: 241 mask_data = self._postprocess_small_regions( 242 mask_data, 243 min_mask_region_area, 244 max(box_nms_thresh, crop_nms_thresh), 245 ) 246 247 # encode masks 248 if output_mode == "coco_rle": 249 mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]] 250 elif output_mode == "binary_mask": 251 mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]] 252 else: 253 mask_data["segmentations"] = mask_data["rles"] 254 255 # write mask records 256 curr_anns = [] 257 for idx in range(len(mask_data["segmentations"])): 258 ann = { 259 "segmentation": mask_data["segmentations"][idx], 260 "area": amg_utils.area_from_rle(mask_data["rles"][idx]), 261 "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 262 "predicted_iou": mask_data["iou_preds"][idx].item(), 263 "stability_score": mask_data["stability_score"][idx].item(), 264 "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 265 } 266 # the data from embedding based segmentation doesn't have the 
points 267 # so we skip if the corresponding key can't be found 268 try: 269 ann["point_coords"] = [mask_data["points"][idx].tolist()] 270 except KeyError: 271 pass 272 curr_anns.append(ann) 273 274 return curr_anns 275 276 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): 277 orig_h, orig_w = original_size 278 279 # serialize predictions and store in MaskData 280 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1)) 281 if points is not None: 282 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float) 283 284 del masks 285 286 # calculate the stability scores 287 data["stability_score"] = amg_utils.calculate_stability_score( 288 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset 289 ) 290 291 # threshold masks and calculate boxes 292 data["masks"] = data["masks"] > self._predictor.model.mask_threshold 293 data["masks"] = data["masks"].type(torch.bool) 294 data["boxes"] = batched_mask_to_box(data["masks"]) 295 296 # compress to RLE 297 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 298 # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) 299 data["rles"] = mask_to_rle_pytorch(data["masks"]) 300 del data["masks"] 301 302 return data 303 304 def get_state(self) -> Dict[str, Any]: 305 """Get the initialized state of the mask generator. 306 307 Returns: 308 State of the mask generator. 309 """ 310 if not self.is_initialized: 311 raise RuntimeError("The state has not been computed yet. Call initialize first.") 312 313 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size} 314 315 def set_state(self, state: Dict[str, Any]) -> None: 316 """Set the state of the mask generator. 317 318 Args: 319 state: The state of the mask generator, e.g. from serialized state. 320 """ 321 self._crop_list = state["crop_list"] 322 self._crop_boxes = state["crop_boxes"] 323 self._original_size = state["original_size"] 324 self._is_initialized = True 325 326 def clear_state(self): 327 """Clear the state of the mask generator. 328 """ 329 self._crop_list = None 330 self._crop_boxes = None 331 self._original_size = None 332 self._is_initialized = False 333 334 335class AutomaticMaskGenerator(AMGBase): 336 """Generates an instance segmentation without prompts, using a point grid. 337 338 This class implements the same logic as 339 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 340 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 341 to filter these masks to enable grid search and interactively changing the post-processing. 342 343 Use this class as follows: 344 ```python 345 amg = AutomaticMaskGenerator(predictor) 346 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 347 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 348 ``` 349 350 Args: 351 predictor: The segment anything predictor. 352 points_per_side: The number of points to be sampled along one side of the image. 353 If None, `point_grids` must provide explicit point sampling. 354 points_per_batch: The number of points run simultaneously by the model. 355 Higher numbers may be faster but use more GPU memory. 356 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 
357 crop_overlap_ratio: Sets the degree to which crops overlap. 358 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 359 point_grids: A lisst over explicit grids of points used for sampling masks. 360 Normalized to [0, 1] with respect to the image coordinate system. 361 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 362 """ 363 def __init__( 364 self, 365 predictor: SamPredictor, 366 points_per_side: Optional[int] = 32, 367 points_per_batch: Optional[int] = None, 368 crop_n_layers: int = 0, 369 crop_overlap_ratio: float = 512 / 1500, 370 crop_n_points_downscale_factor: int = 1, 371 point_grids: Optional[List[np.ndarray]] = None, 372 stability_score_offset: float = 1.0, 373 ): 374 super().__init__() 375 376 if points_per_side is not None: 377 self.point_grids = amg_utils.build_all_layer_point_grids( 378 points_per_side, 379 crop_n_layers, 380 crop_n_points_downscale_factor, 381 ) 382 elif point_grids is not None: 383 self.point_grids = point_grids 384 else: 385 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 386 387 self._predictor = predictor 388 self._points_per_side = points_per_side 389 390 # we set the points per batch to 16 for mps for performance reasons 391 # and otherwise keep them at the default of 64 392 if points_per_batch is None: 393 points_per_batch = 16 if str(predictor.device) == "mps" else 64 394 self._points_per_batch = points_per_batch 395 396 self._crop_n_layers = crop_n_layers 397 self._crop_overlap_ratio = crop_overlap_ratio 398 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 399 self._stability_score_offset = stability_score_offset 400 401 def _process_batch(self, points, im_size, crop_box, original_size): 402 # run model on this batch 403 transformed_points = self._predictor.transform.apply_coords(points, im_size) 404 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 405 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 406 masks, iou_preds, _ = self._predictor.predict_torch( 407 in_points[:, None, :], 408 in_labels[:, None], 409 multimask_output=True, 410 return_logits=True, 411 ) 412 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 413 del masks 414 return data 415 416 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 417 # Crop the image and calculate embeddings. 418 x0, y0, x1, y1 = crop_box 419 cropped_im = image[y0:y1, x0:x1, :] 420 cropped_im_size = cropped_im.shape[:2] 421 422 if not precomputed_embeddings: 423 self._predictor.set_image(cropped_im) 424 425 # Get the points for this crop. 426 points_scale = np.array(cropped_im_size)[None, ::-1] 427 points_for_image = self.point_grids[crop_layer_idx] * points_scale 428 429 # Generate masks for this crop in batches. 
430 data = amg_utils.MaskData() 431 n_batches = len(points_for_image) // self._points_per_batch +\ 432 int(len(points_for_image) % self._points_per_batch != 0) 433 if pbar_init is not None: 434 pbar_init(n_batches, "Predict masks for point grid prompts") 435 436 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 437 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 438 data.cat(batch_data) 439 del batch_data 440 if pbar_update is not None: 441 pbar_update(1) 442 443 if not precomputed_embeddings: 444 self._predictor.reset_image() 445 446 return data 447 448 @torch.no_grad() 449 def initialize( 450 self, 451 image: np.ndarray, 452 image_embeddings: Optional[util.ImageEmbeddings] = None, 453 i: Optional[int] = None, 454 verbose: bool = False, 455 pbar_init: Optional[callable] = None, 456 pbar_update: Optional[callable] = None, 457 ) -> None: 458 """Initialize image embeddings and masks for an image. 459 460 Args: 461 image: The input image, volume or timeseries. 462 image_embeddings: Optional precomputed image embeddings. 463 See `util.precompute_image_embeddings` for details. 464 i: Index for the image data. Required if `image` has three spatial dimensions 465 or a time dimension and two spatial dimensions. 466 verbose: Whether to print computation progress. 467 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 468 Can be used together with pbar_update to handle napari progress bar in other thread. 469 To enables using this function within a threadworker. 470 pbar_update: Callback to update an external progress bar. 471 """ 472 original_size = image.shape[:2] 473 self._original_size = original_size 474 475 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 476 original_size, self._crop_n_layers, self._crop_overlap_ratio 477 ) 478 479 # We can set fixed image embeddings if we only have a single crop box (the default setting). 480 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 481 if len(crop_boxes) == 1: 482 if image_embeddings is None: 483 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 484 util.set_precomputed(self._predictor, image_embeddings, i=i) 485 precomputed_embeddings = True 486 else: 487 precomputed_embeddings = False 488 489 # we need to cast to the image representation that is compatible with SAM 490 image = util._to_image(image) 491 492 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 493 494 crop_list = [] 495 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 496 crop_data = self._process_crop( 497 image, crop_box, layer_idx, 498 precomputed_embeddings=precomputed_embeddings, 499 pbar_init=pbar_init, pbar_update=pbar_update, 500 ) 501 crop_list.append(crop_data) 502 pbar_close() 503 504 self._is_initialized = True 505 self._crop_list = crop_list 506 self._crop_boxes = crop_boxes 507 508 @torch.no_grad() 509 def generate( 510 self, 511 pred_iou_thresh: float = 0.88, 512 stability_score_thresh: float = 0.95, 513 box_nms_thresh: float = 0.7, 514 crop_nms_thresh: float = 0.7, 515 min_mask_region_area: int = 0, 516 output_mode: str = "binary_mask", 517 ) -> List[Dict[str, Any]]: 518 """Generate instance segmentation for the currently initialized image. 519 520 Args: 521 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 
522 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 523 under changes to the cutoff used to binarize the model prediction. 524 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 525 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 526 min_mask_region_area: Minimal size for the predicted masks. 527 output_mode: The form masks are returned in. 528 529 Returns: 530 The instance segmentation masks. 531 """ 532 if not self.is_initialized: 533 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 534 535 data = amg_utils.MaskData() 536 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 537 crop_data = self._postprocess_batch( 538 data=deepcopy(data_), 539 crop_box=crop_box, original_size=self.original_size, 540 pred_iou_thresh=pred_iou_thresh, 541 stability_score_thresh=stability_score_thresh, 542 box_nms_thresh=box_nms_thresh 543 ) 544 data.cat(crop_data) 545 546 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 547 # Prefer masks from smaller crops 548 scores = 1 / box_area(data["crop_boxes"]) 549 scores = scores.to(data["boxes"].device) 550 keep_by_nms = batched_nms( 551 data["boxes"].float(), 552 scores, 553 torch.zeros_like(data["boxes"][:, 0]), # categories 554 iou_threshold=crop_nms_thresh, 555 ) 556 data.filter(keep_by_nms) 557 558 data.to_numpy() 559 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 560 return masks 561 562 563# Helper function for tiled embedding computation and checking consistent state. 564def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo): 565 if image_embeddings is None: 566 if tile_shape is None or halo is None: 567 raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.") 568 image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo) 569 570 # Use tile shape and halo from the precomputed embeddings if not given. 571 # Otherwise check that they are consistent. 572 feats = image_embeddings["features"] 573 tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"] 574 if tile_shape is None: 575 tile_shape = tile_shape_ 576 elif tile_shape != tile_shape_: 577 raise ValueError( 578 f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}." 579 ) 580 if halo is None: 581 halo = halo_ 582 elif halo != halo_: 583 raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.") 584 585 return image_embeddings, tile_shape, halo 586 587 588class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 589 """Generates an instance segmentation without prompts, using a point grid. 590 591 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 592 593 Args: 594 predictor: The segment anything predictor. 595 points_per_side: The number of points to be sampled along one side of the image. 596 If None, `point_grids` must provide explicit point sampling. 597 points_per_batch: The number of points run simultaneously by the model. 598 Higher numbers may be faster but use more GPU memory. 599 point_grids: A lisst over explicit grids of points used for sampling masks. 600 Normalized to [0, 1] with respect to the image coordinate system. 
601 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 602 """ 603 604 # We only expose the arguments that make sense for the tiled mask generator. 605 # Anything related to crops doesn't make sense, because we re-use that functionality 606 # for tiling, so these parameters wouldn't have any effect. 607 def __init__( 608 self, 609 predictor: SamPredictor, 610 points_per_side: Optional[int] = 32, 611 points_per_batch: int = 64, 612 point_grids: Optional[List[np.ndarray]] = None, 613 stability_score_offset: float = 1.0, 614 ) -> None: 615 super().__init__( 616 predictor=predictor, 617 points_per_side=points_per_side, 618 points_per_batch=points_per_batch, 619 point_grids=point_grids, 620 stability_score_offset=stability_score_offset, 621 ) 622 623 @torch.no_grad() 624 def initialize( 625 self, 626 image: np.ndarray, 627 image_embeddings: Optional[util.ImageEmbeddings] = None, 628 i: Optional[int] = None, 629 tile_shape: Optional[Tuple[int, int]] = None, 630 halo: Optional[Tuple[int, int]] = None, 631 verbose: bool = False, 632 pbar_init: Optional[callable] = None, 633 pbar_update: Optional[callable] = None, 634 ) -> None: 635 """Initialize image embeddings and masks for an image. 636 637 Args: 638 image: The input image, volume or timeseries. 639 image_embeddings: Optional precomputed image embeddings. 640 See `util.precompute_image_embeddings` for details. 641 i: Index for the image data. Required if `image` has three spatial dimensions 642 or a time dimension and two spatial dimensions. 643 tile_shape: The tile shape for embedding prediction. 644 halo: The overlap of between tiles. 645 verbose: Whether to print computation progress. 646 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 647 Can be used together with pbar_update to handle napari progress bar in other thread. 648 To enables using this function within a threadworker. 649 pbar_update: Callback to update an external progress bar. 650 """ 651 original_size = image.shape[:2] 652 self._original_size = original_size 653 654 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 655 self._predictor, image, image_embeddings, tile_shape, halo 656 ) 657 658 tiling = blocking([0, 0], original_size, tile_shape) 659 n_tiles = tiling.numberOfBlocks 660 661 # The crop box is always the full local tile. 662 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 663 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 664 665 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 666 pbar_init(n_tiles, "Compute masks for tile") 667 668 # We need to cast to the image representation that is compatible with SAM. 
669 image = util._to_image(image) 670 671 mask_data = [] 672 for tile_id in range(n_tiles): 673 # set the pre-computed embeddings for this tile 674 features = image_embeddings["features"][tile_id] 675 tile_embeddings = { 676 "features": features, 677 "input_size": features.attrs["input_size"], 678 "original_size": features.attrs["original_size"], 679 } 680 util.set_precomputed(self._predictor, tile_embeddings, i) 681 682 # compute the mask data for this tile and append it 683 this_mask_data = self._process_crop( 684 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 685 ) 686 mask_data.append(this_mask_data) 687 pbar_update(1) 688 pbar_close() 689 690 # set the initialized data 691 self._is_initialized = True 692 self._crop_list = mask_data 693 self._crop_boxes = crop_boxes 694 695 696# 697# Instance segmentation functionality based on fine-tuned decoder 698# 699 700 701class DecoderAdapter(torch.nn.Module): 702 """Adapter to contain the UNETR decoder in a single module. 703 704 To apply the decoder on top of pre-computed embeddings for 705 the segmentation functionality. 706 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 707 """ 708 def __init__(self, unetr): 709 super().__init__() 710 711 self.base = unetr.base 712 self.out_conv = unetr.out_conv 713 self.deconv_out = unetr.deconv_out 714 self.decoder_head = unetr.decoder_head 715 self.final_activation = unetr.final_activation 716 self.postprocess_masks = unetr.postprocess_masks 717 718 self.decoder = unetr.decoder 719 self.deconv1 = unetr.deconv1 720 self.deconv2 = unetr.deconv2 721 self.deconv3 = unetr.deconv3 722 self.deconv4 = unetr.deconv4 723 724 def forward(self, input_, input_shape, original_shape): 725 z12 = input_ 726 727 z9 = self.deconv1(z12) 728 z6 = self.deconv2(z9) 729 z3 = self.deconv3(z6) 730 z0 = self.deconv4(z3) 731 732 updated_from_encoder = [z9, z6, z3] 733 734 x = self.base(z12) 735 x = self.decoder(x, encoder_inputs=updated_from_encoder) 736 x = self.deconv_out(x) 737 738 x = torch.cat([x, z0], dim=1) 739 x = self.decoder_head(x) 740 741 x = self.out_conv(x) 742 if self.final_activation is not None: 743 x = self.final_activation(x) 744 745 x = self.postprocess_masks(x, input_shape, original_shape) 746 return x 747 748 749def get_unetr( 750 image_encoder: torch.nn.Module, 751 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 752 device: Optional[Union[str, torch.device]] = None, 753) -> torch.nn.Module: 754 """Get UNETR model for automatic instance segmentation. 755 756 Args: 757 image_encoder: The image encoder of the SAM model. 758 This is used as encoder by the UNETR too. 759 decoder_state: Optional decoder state to initialize the weights 760 of the UNETR decoder. 761 device: The device. 762 Returns: 763 The UNETR model. 
764 """ 765 device = util.get_device(device) 766 767 unetr = UNETR( 768 backbone="sam", 769 encoder=image_encoder, 770 out_channels=3, 771 use_sam_stats=True, 772 final_activation="Sigmoid", 773 use_skip_connection=False, 774 resize_input=True, 775 ) 776 if decoder_state is not None: 777 unetr_state_dict = unetr.state_dict() 778 for k, v in unetr_state_dict.items(): 779 if not k.startswith("encoder"): 780 unetr_state_dict[k] = decoder_state[k] 781 unetr.load_state_dict(unetr_state_dict) 782 783 unetr.to(device) 784 return unetr 785 786 787def get_decoder( 788 image_encoder: torch.nn.Module, 789 decoder_state: OrderedDict[str, torch.Tensor], 790 device: Optional[Union[str, torch.device]] = None, 791) -> DecoderAdapter: 792 """Get decoder to predict outputs for automatic instance segmentation 793 794 Args: 795 image_encoder: The image encoder of the SAM model. 796 decoder_state: State to initialize the weights of the UNETR decoder. 797 device: The device. 798 Returns: 799 The decoder for instance segmentation. 800 """ 801 unetr = get_unetr(image_encoder, decoder_state, device) 802 return DecoderAdapter(unetr) 803 804 805def get_predictor_and_decoder( 806 model_type: str, 807 checkpoint_path: Union[str, os.PathLike], 808 device: Optional[Union[str, torch.device]] = None, 809 peft_kwargs: Optional[Dict] = None, 810) -> Tuple[SamPredictor, DecoderAdapter]: 811 """Load the SAM model (predictor) and instance segmentation decoder. 812 813 This requires a checkpoint that contains the state for both predictor 814 and decoder. 815 816 Args: 817 model_type: The type of the image encoder used in the SAM model. 818 checkpoint_path: Path to the checkpoint from which to load the data. 819 device: The device. 820 lora_rank: The rank for low rank adaptation of the attention layers. 821 822 Returns: 823 The SAM predictor. 824 The decoder for instance segmentation. 825 """ 826 device = util.get_device(device) 827 predictor, state = util.get_sam_model( 828 model_type=model_type, 829 checkpoint_path=checkpoint_path, 830 device=device, 831 return_state=True, 832 peft_kwargs=peft_kwargs, 833 ) 834 if "decoder_state" not in state: 835 raise ValueError( 836 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 837 ) 838 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 839 return predictor, decoder 840 841 842class InstanceSegmentationWithDecoder: 843 """Generates an instance segmentation without prompts, using a decoder. 844 845 Implements the same interface as `AutomaticMaskGenerator`. 846 847 Use this class as follows: 848 ```python 849 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 850 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 851 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 852 ``` 853 854 Args: 855 predictor: The segment anything predictor. 856 decoder: The decoder to predict intermediate representations 857 for instance segmentation. 858 """ 859 def __init__( 860 self, 861 predictor: SamPredictor, 862 decoder: torch.nn.Module, 863 ) -> None: 864 self._predictor = predictor 865 self._decoder = decoder 866 867 # The decoder outputs. 868 self._foreground = None 869 self._center_distances = None 870 self._boundary_distances = None 871 872 self._is_initialized = False 873 874 @property 875 def is_initialized(self): 876 """Whether the mask generator has already been initialized. 
877 """ 878 return self._is_initialized 879 880 @torch.no_grad() 881 def initialize( 882 self, 883 image: np.ndarray, 884 image_embeddings: Optional[util.ImageEmbeddings] = None, 885 i: Optional[int] = None, 886 verbose: bool = False, 887 pbar_init: Optional[callable] = None, 888 pbar_update: Optional[callable] = None, 889 ) -> None: 890 """Initialize image embeddings and decoder predictions for an image. 891 892 Args: 893 image: The input image, volume or timeseries. 894 image_embeddings: Optional precomputed image embeddings. 895 See `util.precompute_image_embeddings` for details. 896 i: Index for the image data. Required if `image` has three spatial dimensions 897 or a time dimension and two spatial dimensions. 898 verbose: Whether to be verbose. 899 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 900 Can be used together with pbar_update to handle napari progress bar in other thread. 901 To enables using this function within a threadworker. 902 pbar_update: Callback to update an external progress bar. 903 """ 904 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 905 pbar_init(1, "Initialize instance segmentation with decoder") 906 907 if image_embeddings is None: 908 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 909 910 # Get the image embeddings from the predictor. 911 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 912 embeddings = self._predictor.features 913 input_shape = tuple(self._predictor.input_size) 914 original_shape = tuple(self._predictor.original_size) 915 916 # Run prediction with the UNETR decoder. 917 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 918 assert output.shape[0] == 3, f"{output.shape}" 919 pbar_update(1) 920 pbar_close() 921 922 # Set the state. 923 self._foreground = output[0] 924 self._center_distances = output[1] 925 self._boundary_distances = output[2] 926 self._is_initialized = True 927 928 def _to_masks(self, segmentation, output_mode): 929 if output_mode != "binary_mask": 930 raise NotImplementedError 931 932 props = regionprops(segmentation) 933 ndim = segmentation.ndim 934 assert ndim in (2, 3) 935 936 shape = segmentation.shape 937 if ndim == 2: 938 crop_box = [0, shape[1], 0, shape[0]] 939 else: 940 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 941 942 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 943 def to_bbox_2d(bbox): 944 y0, x0 = bbox[0], bbox[1] 945 w = bbox[3] - x0 946 h = bbox[2] - y0 947 return [x0, w, y0, h] 948 949 def to_bbox_3d(bbox): 950 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 951 w = bbox[5] - x0 952 h = bbox[4] - y0 953 d = bbox[3] - y0 954 return [x0, w, y0, h, z0, d] 955 956 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 957 masks = [ 958 { 959 "segmentation": segmentation == prop.label, 960 "area": prop.area, 961 "bbox": to_bbox(prop.bbox), 962 "crop_box": crop_box, 963 "seg_id": prop.label, 964 } for prop in props 965 ] 966 return masks 967 968 def generate( 969 self, 970 center_distance_threshold: float = 0.5, 971 boundary_distance_threshold: float = 0.5, 972 foreground_threshold: float = 0.5, 973 foreground_smoothing: float = 1.0, 974 distance_smoothing: float = 1.6, 975 min_size: int = 0, 976 output_mode: Optional[str] = "binary_mask", 977 ) -> List[Dict[str, Any]]: 978 """Generate instance segmentation for the currently initialized image. 
979 980 Args: 981 center_distance_threshold: Center distance predictions below this value will be 982 used to find seeds (intersected with thresholded boundary distance predictions). 983 boundary_distance_threshold: Boundary distance predictions below this value will be 984 used to find seeds (intersected with thresholded center distance predictions). 985 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 986 checkerboard artifacts in the prediction. 987 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 988 distance_smoothing: Sigma value for smoothing the distance predictions. 989 min_size: Minimal object size in the segmentation result. 990 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 991 992 Returns: 993 The instance segmentation masks. 994 """ 995 if not self.is_initialized: 996 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 997 998 if foreground_smoothing > 0: 999 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1000 else: 1001 foreground = self._foreground 1002 # Further optimization: parallel implementation using elf.parallel functionality. 1003 # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) 1004 segmentation = watershed_from_center_and_boundary_distances( 1005 self._center_distances, self._boundary_distances, foreground, 1006 center_distance_threshold=center_distance_threshold, 1007 boundary_distance_threshold=boundary_distance_threshold, 1008 foreground_threshold=foreground_threshold, 1009 distance_smoothing=distance_smoothing, 1010 min_size=min_size, 1011 ) 1012 if output_mode is not None: 1013 segmentation = self._to_masks(segmentation, output_mode) 1014 return segmentation 1015 1016 def get_state(self) -> Dict[str, Any]: 1017 """Get the initialized state of the instance segmenter. 1018 1019 Returns: 1020 Instance segmentation state. 1021 """ 1022 if not self.is_initialized: 1023 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1024 1025 return { 1026 "foreground": self._foreground, 1027 "center_distances": self._center_distances, 1028 "boundary_distances": self._boundary_distances, 1029 } 1030 1031 def set_state(self, state: Dict[str, Any]) -> None: 1032 """Set the state of the instance segmenter. 1033 1034 Args: 1035 state: The instance segmentation state 1036 """ 1037 self._foreground = state["foreground"] 1038 self._center_distances = state["center_distances"] 1039 self._boundary_distances = state["boundary_distances"] 1040 self._is_initialized = True 1041 1042 def clear_state(self): 1043 """Clear the state of the instance segmenter. 1044 """ 1045 self._foreground = None 1046 self._center_distances = None 1047 self._boundary_distances = None 1048 self._is_initialized = False 1049 1050 1051class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1052 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 
1053 """ 1054 1055 @torch.no_grad() 1056 def initialize( 1057 self, 1058 image: np.ndarray, 1059 image_embeddings: Optional[util.ImageEmbeddings] = None, 1060 i: Optional[int] = None, 1061 tile_shape: Optional[Tuple[int, int]] = None, 1062 halo: Optional[Tuple[int, int]] = None, 1063 verbose: bool = False, 1064 pbar_init: Optional[callable] = None, 1065 pbar_update: Optional[callable] = None, 1066 ) -> None: 1067 """Initialize image embeddings and decoder predictions for an image. 1068 1069 Args: 1070 image: The input image, volume or timeseries. 1071 image_embeddings: Optional precomputed image embeddings. 1072 See `util.precompute_image_embeddings` for details. 1073 i: Index for the image data. Required if `image` has three spatial dimensions 1074 or a time dimension and two spatial dimensions. 1075 verbose: Dummy input to be compatible with other function signatures. 1076 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1077 Can be used together with pbar_update to handle napari progress bar in other thread. 1078 To enables using this function within a threadworker. 1079 pbar_update: Callback to update an external progress bar. 1080 """ 1081 original_size = image.shape[:2] 1082 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1083 self._predictor, image, image_embeddings, tile_shape, halo 1084 ) 1085 tiling = blocking([0, 0], original_size, tile_shape) 1086 1087 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1088 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1089 1090 foreground = np.zeros(original_size, dtype="float32") 1091 center_distances = np.zeros(original_size, dtype="float32") 1092 boundary_distances = np.zeros(original_size, dtype="float32") 1093 1094 for tile_id in range(tiling.numberOfBlocks): 1095 1096 # Get the image embeddings from the predictor for this tile. 1097 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1098 embeddings = self._predictor.features 1099 input_shape = tuple(self._predictor.input_size) 1100 original_shape = tuple(self._predictor.original_size) 1101 1102 # Predict with the UNETR decoder for this tile. 1103 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1104 assert output.shape[0] == 3, f"{output.shape}" 1105 1106 # Set the predictions in the output for this tile. 1107 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1108 local_bb = tuple( 1109 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1110 ) 1111 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1112 1113 foreground[inner_bb] = output[0][local_bb] 1114 center_distances[inner_bb] = output[1][local_bb] 1115 boundary_distances[inner_bb] = output[2][local_bb] 1116 pbar_close() 1117 1118 # Set the state. 1119 self._foreground = foreground 1120 self._center_distances = center_distances 1121 self._boundary_distances = boundary_distances 1122 self._is_initialized = True 1123 1124 1125def get_amg( 1126 predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs, 1127) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1128 """Get the automatic mask generator class. 1129 1130 Args: 1131 predictor: The segment anything predictor. 1132 is_tiled: Whether tiled embeddings are used. 
1133 decoder: Decoder to predict instacne segmmentation. 1134 kwargs: The keyword arguments for the amg class. 1135 1136 Returns: 1137 The automatic mask generator. 1138 """ 1139 if decoder is None: 1140 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1141 segmenter = segmenter_class(predictor, **kwargs) 1142 else: 1143 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1144 segmenter = segmenter_class(predictor, decoder, **kwargs) 1145 1146 return segmenter
```python
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    with_background: bool,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
            Only supports output_mode=binary_mask.
        with_background: Whether the segmentation has background. If yes this function assures that the largest
            object in the output will be mapped to zero (the background value).
        min_object_size: The minimal size of an object in pixels.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.

    Returns:
        The instance segmentation.
    """
    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    # we could also get the shape from the crop box
    shape = next(iter(masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask in masks:
        if mask["area"] < min_object_size:
            continue
        if max_object_size is not None and mask["area"] > max_object_size:
            continue

        this_seg_id = mask.get("seg_id", seg_id)
        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
        seg_id = this_seg_id + 1

    if label_masks:
        segmentation = label(segmentation).astype(segmentation.dtype)

    seg_ids, sizes = np.unique(segmentation, return_counts=True)

    # In some cases objects may be smaller than previously calculated,
    # since they are covered by other objects. We ensure these also get
    # filtered out here.
    filter_ids = seg_ids[sizes < min_object_size]

    # If we run segmentation with background we also map the largest segment
    # (the most likely background object) to zero. This is often zero already,
    # but it does not hurt to reset that to zero either.
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    segmentation[np.isin(segmentation, filter_ids)] = 0
    segmentation = relabel_sequential(segmentation)[0]

    return segmentation
```
Convert the output of the automatic mask generation to an instance segmentation.
Arguments:
- masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
- with_background: Whether the segmentation has background. If yes, this function ensures that the largest object in the output is mapped to zero (the background value).
- min_object_size: The minimal size of an object in pixels.
- max_object_size: The maximal size of an object in pixels.
- label_masks: Whether to apply connected components to the result before removing small objects.
Returns:
The instance segmentation.
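As an illustration of the filtering options, the snippet below assumes `masks` was produced by `AutomaticMaskGenerator.generate` with the default `output_mode="binary_mask"`; the size threshold is an arbitrary example value.

```python
# Keep only objects with at least 100 pixels and map the largest segment
# (assumed to be background) to label 0.
segmentation = mask_data_to_segmentation(
    masks,
    with_background=True,
    min_object_size=100,
)

# The result is relabeled sequentially, so the maximum label equals the number of objects.
print(segmentation.max(), "objects")
```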
```python
class AMGBase(ABC):
```
Base class for the automatic mask generators.
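The generators derived from this base separate the expensive `initialize` step from the cheap `generate` step, so post-processing parameters can be varied without recomputing embeddings. A minimal sketch, assuming `amg` is an already initialized generator such as `AutomaticMaskGenerator`:

```python
# Re-run only the cheap post-processing with different thresholds, e.g. for a small grid search.
for iou_thresh in (0.8, 0.85, 0.9):
    masks = amg.generate(pred_iou_thresh=iou_thresh)
    print(iou_thresh, len(masks))
```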
```python
@property
def is_initialized(self):
    """Whether the mask generator has already been initialized.
    """
    return self._is_initialized
```
Whether the mask generator has already been initialized.
```python
@property
def crop_list(self):
    """The list of mask data after initialization.
    """
    return self._crop_list
```
The list of mask data after initialization.
```python
@property
def crop_boxes(self):
    """The list of crop boxes.
    """
    return self._crop_boxes
```
The list of crop boxes.
```python
@property
def original_size(self):
    """The original image size.
    """
    return self._original_size
```
The original image size.
```python
def get_state(self) -> Dict[str, Any]:
    """Get the initialized state of the mask generator.

    Returns:
        State of the mask generator.
    """
    if not self.is_initialized:
        raise RuntimeError("The state has not been computed yet. Call initialize first.")

    return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
```
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
```python
def set_state(self, state: Dict[str, Any]) -> None:
    """Set the state of the mask generator.

    Args:
        state: The state of the mask generator, e.g. from serialized state.
    """
    self._crop_list = state["crop_list"]
    self._crop_boxes = state["crop_boxes"]
    self._original_size = state["original_size"]
    self._is_initialized = True
```
Set the state of the mask generator.
Arguments:
- state: The state of the mask generator, e.g. from serialized state.
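Because `initialize` is the expensive step, the state accessors above can be used to cache and restore a computed state. The state is a plain dictionary, so serializing it with pickle is one option; this is a sketch under that assumption and the file name is a placeholder, not part of the API.

```python
import pickle

# After amg.initialize(image) the state can be serialized ...
with open("amg_state.pickle", "wb") as f:
    pickle.dump(amg.get_state(), f)

# ... and later restored into a fresh mask generator without re-running initialize.
amg2 = AutomaticMaskGenerator(predictor)
with open("amg_state.pickle", "rb") as f:
    amg2.set_state(pickle.load(f))
masks = amg2.generate(pred_iou_thresh=0.88)
```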
```python
class AutomaticMaskGenerator(AMGBase):
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: Optional[int] = None,
        crop_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ):
```
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py It decouples the computationally expensive steps of generating masks from the cheap post-processing operation to filter these masks to enable grid search and interactively changing the post-processing.
Use this class as follows:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image) # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
- crop_overlap_ratio: Sets the degree to which crops overlap.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
- point_grids: A list of explicit grids of points used for sampling masks, normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
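A minimal end-to-end sketch, assuming a `vit_b` checkpoint is available through `micro_sam.util.get_sam_model` and using a random array in place of real image data; `mask_data_to_segmentation` from this module turns the generated mask dictionaries into a single label image:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Load a Segment Anything predictor; the model type is just an example.
predictor = util.get_sam_model(model_type="vit_b")

# A random image stands in for real data.
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Expensive step: embeddings and mask prediction for the point grid.
masks = amg.generate(pred_iou_thresh=0.88)  # Cheap step: filter and post-process the masks.

# Convert the mask dictionaries into a single instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```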
364 def __init__( 365 self, 366 predictor: SamPredictor, 367 points_per_side: Optional[int] = 32, 368 points_per_batch: Optional[int] = None, 369 crop_n_layers: int = 0, 370 crop_overlap_ratio: float = 512 / 1500, 371 crop_n_points_downscale_factor: int = 1, 372 point_grids: Optional[List[np.ndarray]] = None, 373 stability_score_offset: float = 1.0, 374 ): 375 super().__init__() 376 377 if points_per_side is not None: 378 self.point_grids = amg_utils.build_all_layer_point_grids( 379 points_per_side, 380 crop_n_layers, 381 crop_n_points_downscale_factor, 382 ) 383 elif point_grids is not None: 384 self.point_grids = point_grids 385 else: 386 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 387 388 self._predictor = predictor 389 self._points_per_side = points_per_side 390 391 # we set the points per batch to 16 for mps for performance reasons 392 # and otherwise keep them at the default of 64 393 if points_per_batch is None: 394 points_per_batch = 16 if str(predictor.device) == "mps" else 64 395 self._points_per_batch = points_per_batch 396 397 self._crop_n_layers = crop_n_layers 398 self._crop_overlap_ratio = crop_overlap_ratio 399 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 400 self._stability_score_offset = stability_score_offset
449 @torch.no_grad() 450 def initialize( 451 self, 452 image: np.ndarray, 453 image_embeddings: Optional[util.ImageEmbeddings] = None, 454 i: Optional[int] = None, 455 verbose: bool = False, 456 pbar_init: Optional[callable] = None, 457 pbar_update: Optional[callable] = None, 458 ) -> None: 459 """Initialize image embeddings and masks for an image. 460 461 Args: 462 image: The input image, volume or timeseries. 463 image_embeddings: Optional precomputed image embeddings. 464 See `util.precompute_image_embeddings` for details. 465 i: Index for the image data. Required if `image` has three spatial dimensions 466 or a time dimension and two spatial dimensions. 467 verbose: Whether to print computation progress. 468 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 469 Can be used together with pbar_update to handle napari progress bar in other thread. 470 To enables using this function within a threadworker. 471 pbar_update: Callback to update an external progress bar. 472 """ 473 original_size = image.shape[:2] 474 self._original_size = original_size 475 476 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 477 original_size, self._crop_n_layers, self._crop_overlap_ratio 478 ) 479 480 # We can set fixed image embeddings if we only have a single crop box (the default setting). 481 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 482 if len(crop_boxes) == 1: 483 if image_embeddings is None: 484 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 485 util.set_precomputed(self._predictor, image_embeddings, i=i) 486 precomputed_embeddings = True 487 else: 488 precomputed_embeddings = False 489 490 # we need to cast to the image representation that is compatible with SAM 491 image = util._to_image(image) 492 493 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 494 495 crop_list = [] 496 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 497 crop_data = self._process_crop( 498 image, crop_box, layer_idx, 499 precomputed_embeddings=precomputed_embeddings, 500 pbar_init=pbar_init, pbar_update=pbar_update, 501 ) 502 crop_list.append(crop_data) 503 pbar_close() 504 505 self._is_initialized = True 506 self._crop_list = crop_list 507 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
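When the same image is segmented repeatedly (e.g. during a parameter grid search), the embeddings can be computed once and passed in. A sketch, assuming `util.precompute_image_embeddings` accepts a `save_path` argument for caching; the model type, file path and random image are placeholders:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # example model type
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder data

# Precompute (and optionally cache) the embeddings once, then reuse them.
image_embeddings = util.precompute_image_embeddings(predictor, image, save_path="embeddings.zarr")

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image, image_embeddings=image_embeddings, verbose=True)
```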
509 @torch.no_grad() 510 def generate( 511 self, 512 pred_iou_thresh: float = 0.88, 513 stability_score_thresh: float = 0.95, 514 box_nms_thresh: float = 0.7, 515 crop_nms_thresh: float = 0.7, 516 min_mask_region_area: int = 0, 517 output_mode: str = "binary_mask", 518 ) -> List[Dict[str, Any]]: 519 """Generate instance segmentation for the currently initialized image. 520 521 Args: 522 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 523 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 524 under changes to the cutoff used to binarize the model prediction. 525 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 526 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 527 min_mask_region_area: Minimal size for the predicted masks. 528 output_mode: The form masks are returned in. 529 530 Returns: 531 The instance segmentation masks. 532 """ 533 if not self.is_initialized: 534 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 535 536 data = amg_utils.MaskData() 537 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 538 crop_data = self._postprocess_batch( 539 data=deepcopy(data_), 540 crop_box=crop_box, original_size=self.original_size, 541 pred_iou_thresh=pred_iou_thresh, 542 stability_score_thresh=stability_score_thresh, 543 box_nms_thresh=box_nms_thresh 544 ) 545 data.cat(crop_data) 546 547 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 548 # Prefer masks from smaller crops 549 scores = 1 / box_area(data["crop_boxes"]) 550 scores = scores.to(data["boxes"].device) 551 keep_by_nms = batched_nms( 552 data["boxes"].float(), 553 scores, 554 torch.zeros_like(data["boxes"][:, 0]), # categories 555 iou_threshold=crop_nms_thresh, 556 ) 557 data.filter(keep_by_nms) 558 559 data.to_numpy() 560 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 561 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
- min_mask_region_area: Minimal size for the predicted masks.
- output_mode: The form masks are returned in.
Returns:
The instance segmentation masks.
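Since `generate` only re-runs the cheap post-processing, it can be called repeatedly with different thresholds after a single `initialize`. Continuing the sketches above; the threshold values are arbitrary examples:

```python
# Compare post-processing settings without re-running the model.
results = {}
for iou_thresh in (0.75, 0.8, 0.88):
    masks = amg.generate(pred_iou_thresh=iou_thresh, stability_score_thresh=0.9)
    results[iou_thresh] = len(masks)
print(results)  # number of masks kept per threshold
```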
589class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 590 """Generates an instance segmentation without prompts, using a point grid. 591 592 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 593 594 Args: 595 predictor: The segment anything predictor. 596 points_per_side: The number of points to be sampled along one side of the image. 597 If None, `point_grids` must provide explicit point sampling. 598 points_per_batch: The number of points run simultaneously by the model. 599 Higher numbers may be faster but use more GPU memory. 600 point_grids: A lisst over explicit grids of points used for sampling masks. 601 Normalized to [0, 1] with respect to the image coordinate system. 602 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 603 """ 604 605 # We only expose the arguments that make sense for the tiled mask generator. 606 # Anything related to crops doesn't make sense, because we re-use that functionality 607 # for tiling, so these parameters wouldn't have any effect. 608 def __init__( 609 self, 610 predictor: SamPredictor, 611 points_per_side: Optional[int] = 32, 612 points_per_batch: int = 64, 613 point_grids: Optional[List[np.ndarray]] = None, 614 stability_score_offset: float = 1.0, 615 ) -> None: 616 super().__init__( 617 predictor=predictor, 618 points_per_side=points_per_side, 619 points_per_batch=points_per_batch, 620 point_grids=point_grids, 621 stability_score_offset=stability_score_offset, 622 ) 623 624 @torch.no_grad() 625 def initialize( 626 self, 627 image: np.ndarray, 628 image_embeddings: Optional[util.ImageEmbeddings] = None, 629 i: Optional[int] = None, 630 tile_shape: Optional[Tuple[int, int]] = None, 631 halo: Optional[Tuple[int, int]] = None, 632 verbose: bool = False, 633 pbar_init: Optional[callable] = None, 634 pbar_update: Optional[callable] = None, 635 ) -> None: 636 """Initialize image embeddings and masks for an image. 637 638 Args: 639 image: The input image, volume or timeseries. 640 image_embeddings: Optional precomputed image embeddings. 641 See `util.precompute_image_embeddings` for details. 642 i: Index for the image data. Required if `image` has three spatial dimensions 643 or a time dimension and two spatial dimensions. 644 tile_shape: The tile shape for embedding prediction. 645 halo: The overlap of between tiles. 646 verbose: Whether to print computation progress. 647 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 648 Can be used together with pbar_update to handle napari progress bar in other thread. 649 To enables using this function within a threadworker. 650 pbar_update: Callback to update an external progress bar. 651 """ 652 original_size = image.shape[:2] 653 self._original_size = original_size 654 655 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 656 self._predictor, image, image_embeddings, tile_shape, halo 657 ) 658 659 tiling = blocking([0, 0], original_size, tile_shape) 660 n_tiles = tiling.numberOfBlocks 661 662 # The crop box is always the full local tile. 663 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 664 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 665 666 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 667 pbar_init(n_tiles, "Compute masks for tile") 668 669 # We need to cast to the image representation that is compatible with SAM. 
670 image = util._to_image(image) 671 672 mask_data = [] 673 for tile_id in range(n_tiles): 674 # set the pre-computed embeddings for this tile 675 features = image_embeddings["features"][tile_id] 676 tile_embeddings = { 677 "features": features, 678 "input_size": features.attrs["input_size"], 679 "original_size": features.attrs["original_size"], 680 } 681 util.set_precomputed(self._predictor, tile_embeddings, i) 682 683 # compute the mask data for this tile and append it 684 this_mask_data = self._process_crop( 685 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 686 ) 687 mask_data.append(this_mask_data) 688 pbar_update(1) 689 pbar_close() 690 691 # set the initialized data 692 self._is_initialized = True 693 self._crop_list = mask_data 694 self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- point_grids: A list of explicit grids of points used for sampling masks, normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
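A sketch of tiled mask generation, with `tile_shape` and `halo` chosen as example values and a random array standing in for a large image:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

predictor = util.get_sam_model(model_type="vit_b")  # example model type
large_image = np.random.randint(0, 255, size=(2048, 2048), dtype="uint8")  # placeholder data

amg = TiledAutomaticMaskGenerator(predictor)
# Embeddings are computed per tile; tile shape and halo are example values.
amg.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256))
masks = amg.generate(pred_iou_thresh=0.88)
segmentation = mask_data_to_segmentation(masks, with_background=True)
```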
608 def __init__( 609 self, 610 predictor: SamPredictor, 611 points_per_side: Optional[int] = 32, 612 points_per_batch: int = 64, 613 point_grids: Optional[List[np.ndarray]] = None, 614 stability_score_offset: float = 1.0, 615 ) -> None: 616 super().__init__( 617 predictor=predictor, 618 points_per_side=points_per_side, 619 points_per_batch=points_per_batch, 620 point_grids=point_grids, 621 stability_score_offset=stability_score_offset, 622 )
624 @torch.no_grad() 625 def initialize( 626 self, 627 image: np.ndarray, 628 image_embeddings: Optional[util.ImageEmbeddings] = None, 629 i: Optional[int] = None, 630 tile_shape: Optional[Tuple[int, int]] = None, 631 halo: Optional[Tuple[int, int]] = None, 632 verbose: bool = False, 633 pbar_init: Optional[callable] = None, 634 pbar_update: Optional[callable] = None, 635 ) -> None: 636 """Initialize image embeddings and masks for an image. 637 638 Args: 639 image: The input image, volume or timeseries. 640 image_embeddings: Optional precomputed image embeddings. 641 See `util.precompute_image_embeddings` for details. 642 i: Index for the image data. Required if `image` has three spatial dimensions 643 or a time dimension and two spatial dimensions. 644 tile_shape: The tile shape for embedding prediction. 645 halo: The overlap of between tiles. 646 verbose: Whether to print computation progress. 647 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 648 Can be used together with pbar_update to handle napari progress bar in other thread. 649 To enables using this function within a threadworker. 650 pbar_update: Callback to update an external progress bar. 651 """ 652 original_size = image.shape[:2] 653 self._original_size = original_size 654 655 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 656 self._predictor, image, image_embeddings, tile_shape, halo 657 ) 658 659 tiling = blocking([0, 0], original_size, tile_shape) 660 n_tiles = tiling.numberOfBlocks 661 662 # The crop box is always the full local tile. 663 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 664 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 665 666 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 667 pbar_init(n_tiles, "Compute masks for tile") 668 669 # We need to cast to the image representation that is compatible with SAM. 670 image = util._to_image(image) 671 672 mask_data = [] 673 for tile_id in range(n_tiles): 674 # set the pre-computed embeddings for this tile 675 features = image_embeddings["features"][tile_id] 676 tile_embeddings = { 677 "features": features, 678 "input_size": features.attrs["input_size"], 679 "original_size": features.attrs["original_size"], 680 } 681 util.set_precomputed(self._predictor, tile_embeddings, i) 682 683 # compute the mask data for this tile and append it 684 this_mask_data = self._process_crop( 685 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 686 ) 687 mask_data.append(this_mask_data) 688 pbar_update(1) 689 pbar_close() 690 691 # set the initialized data 692 self._is_initialized = True 693 self._crop_list = mask_data 694 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
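The tiled embeddings can also be precomputed and cached once; a sketch, assuming `util.precompute_image_embeddings` accepts `tile_shape`, `halo` and `save_path` arguments (all values are placeholders and `predictor`, `large_image` and `amg` come from the sketch above):

```python
# Precompute the tiled embeddings once and reuse them for initialization.
image_embeddings = util.precompute_image_embeddings(
    predictor, large_image, tile_shape=(1024, 1024), halo=(256, 256), save_path="embeddings.zarr"
)
amg.initialize(large_image, image_embeddings=image_embeddings, tile_shape=(1024, 1024), halo=(256, 256))
```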
702class DecoderAdapter(torch.nn.Module): 703 """Adapter to contain the UNETR decoder in a single module. 704 705 To apply the decoder on top of pre-computed embeddings for 706 the segmentation functionality. 707 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 708 """ 709 def __init__(self, unetr): 710 super().__init__() 711 712 self.base = unetr.base 713 self.out_conv = unetr.out_conv 714 self.deconv_out = unetr.deconv_out 715 self.decoder_head = unetr.decoder_head 716 self.final_activation = unetr.final_activation 717 self.postprocess_masks = unetr.postprocess_masks 718 719 self.decoder = unetr.decoder 720 self.deconv1 = unetr.deconv1 721 self.deconv2 = unetr.deconv2 722 self.deconv3 = unetr.deconv3 723 self.deconv4 = unetr.deconv4 724 725 def forward(self, input_, input_shape, original_shape): 726 z12 = input_ 727 728 z9 = self.deconv1(z12) 729 z6 = self.deconv2(z9) 730 z3 = self.deconv3(z6) 731 z0 = self.deconv4(z3) 732 733 updated_from_encoder = [z9, z6, z3] 734 735 x = self.base(z12) 736 x = self.decoder(x, encoder_inputs=updated_from_encoder) 737 x = self.deconv_out(x) 738 739 x = torch.cat([x, z0], dim=1) 740 x = self.decoder_head(x) 741 742 x = self.out_conv(x) 743 if self.final_activation is not None: 744 x = self.final_activation(x) 745 746 x = self.postprocess_masks(x, input_shape, original_shape) 747 return x
Adapter to contain the UNETR decoder in a single module.
This allows applying the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
709 def __init__(self, unetr): 710 super().__init__() 711 712 self.base = unetr.base 713 self.out_conv = unetr.out_conv 714 self.deconv_out = unetr.deconv_out 715 self.decoder_head = unetr.decoder_head 716 self.final_activation = unetr.final_activation 717 self.postprocess_masks = unetr.postprocess_masks 718 719 self.decoder = unetr.decoder 720 self.deconv1 = unetr.deconv1 721 self.deconv2 = unetr.deconv2 722 self.deconv3 = unetr.deconv3 723 self.deconv4 = unetr.deconv4
Initialize internal Module state, shared by both nn.Module and ScriptModule.
725 def forward(self, input_, input_shape, original_shape): 726 z12 = input_ 727 728 z9 = self.deconv1(z12) 729 z6 = self.deconv2(z9) 730 z3 = self.deconv3(z6) 731 z0 = self.deconv4(z3) 732 733 updated_from_encoder = [z9, z6, z3] 734 735 x = self.base(z12) 736 x = self.decoder(x, encoder_inputs=updated_from_encoder) 737 x = self.deconv_out(x) 738 739 x = torch.cat([x, z0], dim=1) 740 x = self.decoder_head(x) 741 742 x = self.out_conv(x) 743 if self.final_activation is not None: 744 x = self.final_activation(x) 745 746 x = self.postprocess_masks(x, input_shape, original_shape) 747 return x
Define the computation performed at every call. Should be overridden by all subclasses. Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
750def get_unetr( 751 image_encoder: torch.nn.Module, 752 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 753 device: Optional[Union[str, torch.device]] = None, 754) -> torch.nn.Module: 755 """Get UNETR model for automatic instance segmentation. 756 757 Args: 758 image_encoder: The image encoder of the SAM model. 759 This is used as encoder by the UNETR too. 760 decoder_state: Optional decoder state to initialize the weights 761 of the UNETR decoder. 762 device: The device. 763 Returns: 764 The UNETR model. 765 """ 766 device = util.get_device(device) 767 768 unetr = UNETR( 769 backbone="sam", 770 encoder=image_encoder, 771 out_channels=3, 772 use_sam_stats=True, 773 final_activation="Sigmoid", 774 use_skip_connection=False, 775 resize_input=True, 776 ) 777 if decoder_state is not None: 778 unetr_state_dict = unetr.state_dict() 779 for k, v in unetr_state_dict.items(): 780 if not k.startswith("encoder"): 781 unetr_state_dict[k] = decoder_state[k] 782 unetr.load_state_dict(unetr_state_dict) 783 784 unetr.to(device) 785 return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device.
Returns:
The UNETR model.
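A short sketch of building the UNETR from a SAM image encoder, assuming a fine-tuned checkpoint whose state contains a `decoder_state` entry; model type and path are placeholders:

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

# Load a SAM model together with its stored training state.
predictor, state = util.get_sam_model(
    model_type="vit_b", checkpoint_path="finetuned_model.pt", return_state=True
)

# Build the UNETR on top of the SAM image encoder, initializing the decoder weights if available.
unetr = get_unetr(predictor.model.image_encoder, decoder_state=state.get("decoder_state"))
```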
788def get_decoder( 789 image_encoder: torch.nn.Module, 790 decoder_state: OrderedDict[str, torch.Tensor], 791 device: Optional[Union[str, torch.device]] = None, 792) -> DecoderAdapter: 793 """Get decoder to predict outputs for automatic instance segmentation 794 795 Args: 796 image_encoder: The image encoder of the SAM model. 797 decoder_state: State to initialize the weights of the UNETR decoder. 798 device: The device. 799 Returns: 800 The decoder for instance segmentation. 801 """ 802 unetr = get_unetr(image_encoder, decoder_state, device) 803 return DecoderAdapter(unetr)
Get the decoder to predict outputs for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device.
Returns:
The decoder for instance segmentation.
806def get_predictor_and_decoder( 807 model_type: str, 808 checkpoint_path: Union[str, os.PathLike], 809 device: Optional[Union[str, torch.device]] = None, 810 peft_kwargs: Optional[Dict] = None, 811) -> Tuple[SamPredictor, DecoderAdapter]: 812 """Load the SAM model (predictor) and instance segmentation decoder. 813 814 This requires a checkpoint that contains the state for both predictor 815 and decoder. 816 817 Args: 818 model_type: The type of the image encoder used in the SAM model. 819 checkpoint_path: Path to the checkpoint from which to load the data. 820 device: The device. 821 lora_rank: The rank for low rank adaptation of the attention layers. 822 823 Returns: 824 The SAM predictor. 825 The decoder for instance segmentation. 826 """ 827 device = util.get_device(device) 828 predictor, state = util.get_sam_model( 829 model_type=model_type, 830 checkpoint_path=checkpoint_path, 831 device=device, 832 return_state=True, 833 peft_kwargs=peft_kwargs, 834 ) 835 if "decoder_state" not in state: 836 raise ValueError( 837 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 838 ) 839 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 840 return predictor, decoder
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device.
- peft_kwargs: Keyword arguments for parameter-efficient fine-tuning, e.g. the rank for low rank adaptation of the attention layers.
Returns:
The SAM predictor. The decoder for instance segmentation.
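A sketch of loading both components from a single checkpoint and wiring them into the decoder-based segmenter; the model type and checkpoint path are placeholders:

```python
from micro_sam.instance_segmentation import get_predictor_and_decoder, InstanceSegmentationWithDecoder

# The checkpoint must contain both the SAM weights and the decoder state.
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b", checkpoint_path="finetuned_model_with_decoder.pt"
)
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
```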
843class InstanceSegmentationWithDecoder: 844 """Generates an instance segmentation without prompts, using a decoder. 845 846 Implements the same interface as `AutomaticMaskGenerator`. 847 848 Use this class as follows: 849 ```python 850 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 851 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 852 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 853 ``` 854 855 Args: 856 predictor: The segment anything predictor. 857 decoder: The decoder to predict intermediate representations 858 for instance segmentation. 859 """ 860 def __init__( 861 self, 862 predictor: SamPredictor, 863 decoder: torch.nn.Module, 864 ) -> None: 865 self._predictor = predictor 866 self._decoder = decoder 867 868 # The decoder outputs. 869 self._foreground = None 870 self._center_distances = None 871 self._boundary_distances = None 872 873 self._is_initialized = False 874 875 @property 876 def is_initialized(self): 877 """Whether the mask generator has already been initialized. 878 """ 879 return self._is_initialized 880 881 @torch.no_grad() 882 def initialize( 883 self, 884 image: np.ndarray, 885 image_embeddings: Optional[util.ImageEmbeddings] = None, 886 i: Optional[int] = None, 887 verbose: bool = False, 888 pbar_init: Optional[callable] = None, 889 pbar_update: Optional[callable] = None, 890 ) -> None: 891 """Initialize image embeddings and decoder predictions for an image. 892 893 Args: 894 image: The input image, volume or timeseries. 895 image_embeddings: Optional precomputed image embeddings. 896 See `util.precompute_image_embeddings` for details. 897 i: Index for the image data. Required if `image` has three spatial dimensions 898 or a time dimension and two spatial dimensions. 899 verbose: Whether to be verbose. 900 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 901 Can be used together with pbar_update to handle napari progress bar in other thread. 902 To enables using this function within a threadworker. 903 pbar_update: Callback to update an external progress bar. 904 """ 905 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 906 pbar_init(1, "Initialize instance segmentation with decoder") 907 908 if image_embeddings is None: 909 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 910 911 # Get the image embeddings from the predictor. 912 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 913 embeddings = self._predictor.features 914 input_shape = tuple(self._predictor.input_size) 915 original_shape = tuple(self._predictor.original_size) 916 917 # Run prediction with the UNETR decoder. 918 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 919 assert output.shape[0] == 3, f"{output.shape}" 920 pbar_update(1) 921 pbar_close() 922 923 # Set the state. 
924 self._foreground = output[0] 925 self._center_distances = output[1] 926 self._boundary_distances = output[2] 927 self._is_initialized = True 928 929 def _to_masks(self, segmentation, output_mode): 930 if output_mode != "binary_mask": 931 raise NotImplementedError 932 933 props = regionprops(segmentation) 934 ndim = segmentation.ndim 935 assert ndim in (2, 3) 936 937 shape = segmentation.shape 938 if ndim == 2: 939 crop_box = [0, shape[1], 0, shape[0]] 940 else: 941 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 942 943 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 944 def to_bbox_2d(bbox): 945 y0, x0 = bbox[0], bbox[1] 946 w = bbox[3] - x0 947 h = bbox[2] - y0 948 return [x0, w, y0, h] 949 950 def to_bbox_3d(bbox): 951 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 952 w = bbox[5] - x0 953 h = bbox[4] - y0 954 d = bbox[3] - y0 955 return [x0, w, y0, h, z0, d] 956 957 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 958 masks = [ 959 { 960 "segmentation": segmentation == prop.label, 961 "area": prop.area, 962 "bbox": to_bbox(prop.bbox), 963 "crop_box": crop_box, 964 "seg_id": prop.label, 965 } for prop in props 966 ] 967 return masks 968 969 def generate( 970 self, 971 center_distance_threshold: float = 0.5, 972 boundary_distance_threshold: float = 0.5, 973 foreground_threshold: float = 0.5, 974 foreground_smoothing: float = 1.0, 975 distance_smoothing: float = 1.6, 976 min_size: int = 0, 977 output_mode: Optional[str] = "binary_mask", 978 ) -> List[Dict[str, Any]]: 979 """Generate instance segmentation for the currently initialized image. 980 981 Args: 982 center_distance_threshold: Center distance predictions below this value will be 983 used to find seeds (intersected with thresholded boundary distance predictions). 984 boundary_distance_threshold: Boundary distance predictions below this value will be 985 used to find seeds (intersected with thresholded center distance predictions). 986 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 987 checkerboard artifacts in the prediction. 988 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 989 distance_smoothing: Sigma value for smoothing the distance predictions. 990 min_size: Minimal object size in the segmentation result. 991 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 992 993 Returns: 994 The instance segmentation masks. 995 """ 996 if not self.is_initialized: 997 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 998 999 if foreground_smoothing > 0: 1000 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1001 else: 1002 foreground = self._foreground 1003 # Further optimization: parallel implementation using elf.parallel functionality. 
1004 # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) 1005 segmentation = watershed_from_center_and_boundary_distances( 1006 self._center_distances, self._boundary_distances, foreground, 1007 center_distance_threshold=center_distance_threshold, 1008 boundary_distance_threshold=boundary_distance_threshold, 1009 foreground_threshold=foreground_threshold, 1010 distance_smoothing=distance_smoothing, 1011 min_size=min_size, 1012 ) 1013 if output_mode is not None: 1014 segmentation = self._to_masks(segmentation, output_mode) 1015 return segmentation 1016 1017 def get_state(self) -> Dict[str, Any]: 1018 """Get the initialized state of the instance segmenter. 1019 1020 Returns: 1021 Instance segmentation state. 1022 """ 1023 if not self.is_initialized: 1024 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1025 1026 return { 1027 "foreground": self._foreground, 1028 "center_distances": self._center_distances, 1029 "boundary_distances": self._boundary_distances, 1030 } 1031 1032 def set_state(self, state: Dict[str, Any]) -> None: 1033 """Set the state of the instance segmenter. 1034 1035 Args: 1036 state: The instance segmentation state 1037 """ 1038 self._foreground = state["foreground"] 1039 self._center_distances = state["center_distances"] 1040 self._boundary_distances = state["boundary_distances"] 1041 self._is_initialized = True 1042 1043 def clear_state(self): 1044 """Clear the state of the instance segmenter. 1045 """ 1046 self._foreground = None 1047 self._center_distances = None 1048 self._boundary_distances = None 1049 self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as `AutomaticMaskGenerator`.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
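Continuing the sketch above, a minimal run on placeholder data; passing `output_mode=None` to `generate` returns the label image directly instead of mask dictionaries:

```python
import numpy as np

image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder data
segmenter.initialize(image)  # Predict embeddings and decoder outputs (the expensive step).
segmentation = segmenter.generate(min_size=100, output_mode=None)  # Directly get the label image.
```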
860 def __init__( 861 self, 862 predictor: SamPredictor, 863 decoder: torch.nn.Module, 864 ) -> None: 865 self._predictor = predictor 866 self._decoder = decoder 867 868 # The decoder outputs. 869 self._foreground = None 870 self._center_distances = None 871 self._boundary_distances = None 872 873 self._is_initialized = False
875 @property 876 def is_initialized(self): 877 """Whether the mask generator has already been initialized. 878 """ 879 return self._is_initialized
Whether the mask generator has already been initialized.
881 @torch.no_grad() 882 def initialize( 883 self, 884 image: np.ndarray, 885 image_embeddings: Optional[util.ImageEmbeddings] = None, 886 i: Optional[int] = None, 887 verbose: bool = False, 888 pbar_init: Optional[callable] = None, 889 pbar_update: Optional[callable] = None, 890 ) -> None: 891 """Initialize image embeddings and decoder predictions for an image. 892 893 Args: 894 image: The input image, volume or timeseries. 895 image_embeddings: Optional precomputed image embeddings. 896 See `util.precompute_image_embeddings` for details. 897 i: Index for the image data. Required if `image` has three spatial dimensions 898 or a time dimension and two spatial dimensions. 899 verbose: Whether to be verbose. 900 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 901 Can be used together with pbar_update to handle napari progress bar in other thread. 902 To enables using this function within a threadworker. 903 pbar_update: Callback to update an external progress bar. 904 """ 905 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 906 pbar_init(1, "Initialize instance segmentation with decoder") 907 908 if image_embeddings is None: 909 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 910 911 # Get the image embeddings from the predictor. 912 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 913 embeddings = self._predictor.features 914 input_shape = tuple(self._predictor.input_size) 915 original_shape = tuple(self._predictor.original_size) 916 917 # Run prediction with the UNETR decoder. 918 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 919 assert output.shape[0] == 3, f"{output.shape}" 920 pbar_update(1) 921 pbar_close() 922 923 # Set the state. 924 self._foreground = output[0] 925 self._center_distances = output[1] 926 self._boundary_distances = output[2] 927 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
969 def generate( 970 self, 971 center_distance_threshold: float = 0.5, 972 boundary_distance_threshold: float = 0.5, 973 foreground_threshold: float = 0.5, 974 foreground_smoothing: float = 1.0, 975 distance_smoothing: float = 1.6, 976 min_size: int = 0, 977 output_mode: Optional[str] = "binary_mask", 978 ) -> List[Dict[str, Any]]: 979 """Generate instance segmentation for the currently initialized image. 980 981 Args: 982 center_distance_threshold: Center distance predictions below this value will be 983 used to find seeds (intersected with thresholded boundary distance predictions). 984 boundary_distance_threshold: Boundary distance predictions below this value will be 985 used to find seeds (intersected with thresholded center distance predictions). 986 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 987 checkerboard artifacts in the prediction. 988 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 989 distance_smoothing: Sigma value for smoothing the distance predictions. 990 min_size: Minimal object size in the segmentation result. 991 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 992 993 Returns: 994 The instance segmentation masks. 995 """ 996 if not self.is_initialized: 997 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 998 999 if foreground_smoothing > 0: 1000 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1001 else: 1002 foreground = self._foreground 1003 # Further optimization: parallel implementation using elf.parallel functionality. 1004 # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) 1005 segmentation = watershed_from_center_and_boundary_distances( 1006 self._center_distances, self._boundary_distances, foreground, 1007 center_distance_threshold=center_distance_threshold, 1008 boundary_distance_threshold=boundary_distance_threshold, 1009 foreground_threshold=foreground_threshold, 1010 distance_smoothing=distance_smoothing, 1011 min_size=min_size, 1012 ) 1013 if output_mode is not None: 1014 segmentation = self._to_masks(segmentation, output_mode) 1015 return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask.
- distance_smoothing: Sigma value for smoothing the distance predictions.
- min_size: Minimal object size in the segmentation result.
- output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
Returns:
The instance segmentation masks.
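As for the point-grid generator, `generate` only runs the cheap watershed post-processing, so different thresholds can be explored quickly after one `initialize`. Continuing the sketch above; the values are arbitrary examples:

```python
for center_thresh in (0.4, 0.5, 0.6):
    segmentation = segmenter.generate(
        center_distance_threshold=center_thresh,
        boundary_distance_threshold=0.5,
        min_size=50,
        output_mode=None,
    )
    print(center_thresh, segmentation.max())  # highest label id, a proxy for the number of instances
```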
1017 def get_state(self) -> Dict[str, Any]: 1018 """Get the initialized state of the instance segmenter. 1019 1020 Returns: 1021 Instance segmentation state. 1022 """ 1023 if not self.is_initialized: 1024 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1025 1026 return { 1027 "foreground": self._foreground, 1028 "center_distances": self._center_distances, 1029 "boundary_distances": self._boundary_distances, 1030 }
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
1032 def set_state(self, state: Dict[str, Any]) -> None: 1033 """Set the state of the instance segmenter. 1034 1035 Args: 1036 state: The instance segmentation state 1037 """ 1038 self._foreground = state["foreground"] 1039 self._center_distances = state["center_distances"] 1040 self._boundary_distances = state["boundary_distances"] 1041 self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state.
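A sketch of caching the decoder predictions between sessions, reusing `predictor`, `decoder` and `segmenter` from the sketches above; pickle is just one possible serialization and the file name is a placeholder:

```python
import pickle

from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder

# Save the initialized state (foreground and distance predictions) ...
state = segmenter.get_state()
with open("segmenter_state.pkl", "wb") as f:
    pickle.dump(state, f)

# ... and restore it later without recomputing embeddings or decoder outputs.
restored = InstanceSegmentationWithDecoder(predictor, decoder)
with open("segmenter_state.pkl", "rb") as f:
    restored.set_state(pickle.load(f))
segmentation = restored.generate(output_mode=None)
```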
1052class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1053 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1054 """ 1055 1056 @torch.no_grad() 1057 def initialize( 1058 self, 1059 image: np.ndarray, 1060 image_embeddings: Optional[util.ImageEmbeddings] = None, 1061 i: Optional[int] = None, 1062 tile_shape: Optional[Tuple[int, int]] = None, 1063 halo: Optional[Tuple[int, int]] = None, 1064 verbose: bool = False, 1065 pbar_init: Optional[callable] = None, 1066 pbar_update: Optional[callable] = None, 1067 ) -> None: 1068 """Initialize image embeddings and decoder predictions for an image. 1069 1070 Args: 1071 image: The input image, volume or timeseries. 1072 image_embeddings: Optional precomputed image embeddings. 1073 See `util.precompute_image_embeddings` for details. 1074 i: Index for the image data. Required if `image` has three spatial dimensions 1075 or a time dimension and two spatial dimensions. 1076 verbose: Dummy input to be compatible with other function signatures. 1077 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1078 Can be used together with pbar_update to handle napari progress bar in other thread. 1079 To enables using this function within a threadworker. 1080 pbar_update: Callback to update an external progress bar. 1081 """ 1082 original_size = image.shape[:2] 1083 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1084 self._predictor, image, image_embeddings, tile_shape, halo 1085 ) 1086 tiling = blocking([0, 0], original_size, tile_shape) 1087 1088 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1089 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1090 1091 foreground = np.zeros(original_size, dtype="float32") 1092 center_distances = np.zeros(original_size, dtype="float32") 1093 boundary_distances = np.zeros(original_size, dtype="float32") 1094 1095 for tile_id in range(tiling.numberOfBlocks): 1096 1097 # Get the image embeddings from the predictor for this tile. 1098 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1099 embeddings = self._predictor.features 1100 input_shape = tuple(self._predictor.input_size) 1101 original_shape = tuple(self._predictor.original_size) 1102 1103 # Predict with the UNETR decoder for this tile. 1104 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1105 assert output.shape[0] == 3, f"{output.shape}" 1106 1107 # Set the predictions in the output for this tile. 1108 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1109 local_bb = tuple( 1110 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1111 ) 1112 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1113 1114 foreground[inner_bb] = output[0][local_bb] 1115 center_distances[inner_bb] = output[1][local_bb] 1116 boundary_distances[inner_bb] = output[2][local_bb] 1117 pbar_close() 1118 1119 # Set the state. 1120 self._foreground = foreground 1121 self._center_distances = center_distances 1122 self._boundary_distances = boundary_distances 1123 self._is_initialized = True
Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1056 @torch.no_grad() 1057 def initialize( 1058 self, 1059 image: np.ndarray, 1060 image_embeddings: Optional[util.ImageEmbeddings] = None, 1061 i: Optional[int] = None, 1062 tile_shape: Optional[Tuple[int, int]] = None, 1063 halo: Optional[Tuple[int, int]] = None, 1064 verbose: bool = False, 1065 pbar_init: Optional[callable] = None, 1066 pbar_update: Optional[callable] = None, 1067 ) -> None: 1068 """Initialize image embeddings and decoder predictions for an image. 1069 1070 Args: 1071 image: The input image, volume or timeseries. 1072 image_embeddings: Optional precomputed image embeddings. 1073 See `util.precompute_image_embeddings` for details. 1074 i: Index for the image data. Required if `image` has three spatial dimensions 1075 or a time dimension and two spatial dimensions. 1076 verbose: Dummy input to be compatible with other function signatures. 1077 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1078 Can be used together with pbar_update to handle napari progress bar in other thread. 1079 To enables using this function within a threadworker. 1080 pbar_update: Callback to update an external progress bar. 1081 """ 1082 original_size = image.shape[:2] 1083 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1084 self._predictor, image, image_embeddings, tile_shape, halo 1085 ) 1086 tiling = blocking([0, 0], original_size, tile_shape) 1087 1088 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1089 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1090 1091 foreground = np.zeros(original_size, dtype="float32") 1092 center_distances = np.zeros(original_size, dtype="float32") 1093 boundary_distances = np.zeros(original_size, dtype="float32") 1094 1095 for tile_id in range(tiling.numberOfBlocks): 1096 1097 # Get the image embeddings from the predictor for this tile. 1098 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1099 embeddings = self._predictor.features 1100 input_shape = tuple(self._predictor.input_size) 1101 original_shape = tuple(self._predictor.original_size) 1102 1103 # Predict with the UNETR decoder for this tile. 1104 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1105 assert output.shape[0] == 3, f"{output.shape}" 1106 1107 # Set the predictions in the output for this tile. 1108 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1109 local_bb = tuple( 1110 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1111 ) 1112 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1113 1114 foreground[inner_bb] = output[0][local_bb] 1115 center_distances[inner_bb] = output[1][local_bb] 1116 boundary_distances[inner_bb] = output[2][local_bb] 1117 pbar_close() 1118 1119 # Set the state. 1120 self._foreground = foreground 1121 self._center_distances = center_distances 1122 self._boundary_distances = boundary_distances 1123 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
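A sketch of the tiled decoder-based segmenter, reusing the `predictor` and `decoder` from the sketches above; tile shape, halo and the random array are placeholders:

```python
import numpy as np
from micro_sam.instance_segmentation import TiledInstanceSegmentationWithDecoder

large_image = np.random.randint(0, 255, size=(2048, 2048), dtype="uint8")  # placeholder data
tiled_segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
tiled_segmenter.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256))
segmentation = tiled_segmenter.generate(output_mode=None)
```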
1126def get_amg( 1127 predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs, 1128) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1129 """Get the automatic mask generator class. 1130 1131 Args: 1132 predictor: The segment anything predictor. 1133 is_tiled: Whether tiled embeddings are used. 1134 decoder: Decoder to predict instacne segmmentation. 1135 kwargs: The keyword arguments for the amg class. 1136 1137 Returns: 1138 The automatic mask generator. 1139 """ 1140 if decoder is None: 1141 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1142 segmenter = segmenter_class(predictor, **kwargs) 1143 else: 1144 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1145 segmenter = segmenter_class(predictor, decoder, **kwargs) 1146 1147 return segmenter
Get the automatic mask generator class.
Arguments:
- predictor: The segment anything predictor.
- is_tiled: Whether tiled embeddings are used.
- decoder: The decoder to predict the instance segmentation.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator.
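A sketch of selecting the segmenter via `get_amg`; keyword arguments are passed through to the chosen class, and `predictor` and `decoder` are assumed to be loaded as in the sketches above (`points_per_side` is just an example setting):

```python
from micro_sam.instance_segmentation import get_amg

# Point-grid based mask generation (no decoder).
amg = get_amg(predictor, is_tiled=False, points_per_side=64)

# Decoder-based segmentation for tiled embeddings.
tiled_segmenter = get_amg(predictor, is_tiled=True, decoder=decoder)
```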