micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
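A minimal end-to-end sketch of the typical workflow. The model loading via `micro_sam.util.get_sam_model`, the `vit_b` model type, and the input file are assumptions; any `SamPredictor` and 2d image work:

```python
import imageio.v3 as imageio

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

predictor = get_sam_model(model_type="vit_b")  # assumed helper to obtain a SamPredictor
image = imageio.imread("cells.tif")            # hypothetical 2d input image

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)                       # expensive: embeddings and mask prediction
masks = amg.generate(pred_iou_thresh=0.88)  # cheap: only post-processing
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```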
1""" 2Automated instance segmentation functionality. 3The classes implemented here extend the automatic instance segmentation from Segment Anything: 4https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 5""" 6 7import os 8from abc import ABC 9from copy import deepcopy 10from collections import OrderedDict 11from typing import Any, Dict, List, Optional, Tuple, Union 12 13import vigra 14import numpy as np 15import torch 16 17from skimage.measure import label, regionprops 18from skimage.segmentation import relabel_sequential 19from torchvision.ops.boxes import batched_nms, box_area 20 21from torch_em.model import UNETR 22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 23 24from nifty.tools import blocking 25 26import segment_anything.utils.amg as amg_utils 27from segment_anything.predictor import SamPredictor 28 29from . import util 30from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 31 32# 33# Utility Functionality 34# 35 36 37class _FakeInput: 38 def __init__(self, shape): 39 self.shape = shape 40 41 def __getitem__(self, index): 42 block_shape = tuple(ind.stop - ind.start for ind in index) 43 return np.zeros(block_shape, dtype="float32") 44 45 46def mask_data_to_segmentation( 47 masks: List[Dict[str, Any]], 48 with_background: bool, 49 min_object_size: int = 0, 50 max_object_size: Optional[int] = None, 51 label_masks: bool = True, 52) -> np.ndarray: 53 """Convert the output of the automatic mask generation to an instance segmentation. 54 55 Args: 56 masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. 57 Only supports output_mode=binary_mask. 58 with_background: Whether the segmentation has background. If yes this function assures that the largest 59 object in the output will be mapped to zero (the background value). 60 min_object_size: The minimal size of an object in pixels. 61 max_object_size: The maximal size of an object in pixels. 62<<<<<<< HEAD 63 64======= 65 label_masks: Whether to apply connected components to the result before remving small objects. 66>>>>>>> master 67 Returns: 68 The instance segmentation. 69 """ 70 71 masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 72 # we could also get the shape from the crop box 73 shape = next(iter(masks))["segmentation"].shape 74 segmentation = np.zeros(shape, dtype="uint32") 75 76 def require_numpy(mask): 77 return mask.cpu().numpy() if torch.is_tensor(mask) else mask 78 79 seg_id = 1 80 for mask in masks: 81 if mask["area"] < min_object_size: 82 continue 83 if max_object_size is not None and mask["area"] > max_object_size: 84 continue 85 86 this_seg_id = mask.get("seg_id", seg_id) 87 segmentation[require_numpy(mask["segmentation"])] = this_seg_id 88 seg_id = this_seg_id + 1 89 90 if label_masks: 91 segmentation = label(segmentation) 92 seg_ids, sizes = np.unique(segmentation, return_counts=True) 93 94 # In some cases objects may be smaller than peviously calculated, 95 # since they are covered by other objects. We ensure these also get 96 # filtered out here. 97 filter_ids = seg_ids[sizes < min_object_size] 98 99 # If we run segmentation with background we also map the largest segment 100 # (the most likely background object) to zero. This is often zero already, 101 # but it does not hurt to reset that to zero either. 
102 if with_background: 103 bg_id = seg_ids[np.argmax(sizes)] 104 filter_ids = np.concatenate([filter_ids, [bg_id]]) 105 106 segmentation[np.isin(segmentation, filter_ids)] = 0 107 segmentation = relabel_sequential(segmentation)[0] 108 109 return segmentation 110 111 112# 113# Classes for automatic instance segmentation 114# 115 116 117class AMGBase(ABC): 118 """Base class for the automatic mask generators. 119 """ 120 def __init__(self): 121 # the state that has to be computed by the 'initialize' method of the child classes 122 self._is_initialized = False 123 self._crop_list = None 124 self._crop_boxes = None 125 self._original_size = None 126 127 @property 128 def is_initialized(self): 129 """Whether the mask generator has already been initialized. 130 """ 131 return self._is_initialized 132 133 @property 134 def crop_list(self): 135 """The list of mask data after initialization. 136 """ 137 return self._crop_list 138 139 @property 140 def crop_boxes(self): 141 """The list of crop boxes. 142 """ 143 return self._crop_boxes 144 145 @property 146 def original_size(self): 147 """The original image size. 148 """ 149 return self._original_size 150 151 def _postprocess_batch( 152 self, 153 data, 154 crop_box, 155 original_size, 156 pred_iou_thresh, 157 stability_score_thresh, 158 box_nms_thresh, 159 ): 160 orig_h, orig_w = original_size 161 162 # filter by predicted IoU 163 if pred_iou_thresh > 0.0: 164 keep_mask = data["iou_preds"] > pred_iou_thresh 165 data.filter(keep_mask) 166 167 # filter by stability score 168 if stability_score_thresh > 0.0: 169 keep_mask = data["stability_score"] >= stability_score_thresh 170 data.filter(keep_mask) 171 172 # filter boxes that touch crop boundaries 173 keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 174 if not torch.all(keep_mask): 175 data.filter(keep_mask) 176 177 # remove duplicates within this crop. 
178 keep_by_nms = batched_nms( 179 data["boxes"].float(), 180 data["iou_preds"], 181 torch.zeros_like(data["boxes"][:, 0]), # categories 182 iou_threshold=box_nms_thresh, 183 ) 184 data.filter(keep_by_nms) 185 186 # return to the original image frame 187 data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box) 188 data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 189 # the data from embedding based segmentation doesn't have the points 190 # so we skip if the corresponding key can't be found 191 try: 192 data["points"] = amg_utils.uncrop_points(data["points"], crop_box) 193 except KeyError: 194 pass 195 196 return data 197 198 def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): 199 200 if len(mask_data["rles"]) == 0: 201 return mask_data 202 203 # filter small disconnected regions and holes 204 new_masks = [] 205 scores = [] 206 for rle in mask_data["rles"]: 207 mask = amg_utils.rle_to_mask(rle) 208 209 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes") 210 unchanged = not changed 211 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands") 212 unchanged = unchanged and not changed 213 214 new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0)) 215 # give score=0 to changed masks and score=1 to unchanged masks 216 # so NMS will prefer ones that didn't need postprocessing 217 scores.append(float(unchanged)) 218 219 # recalculate boxes and remove any new duplicates 220 masks = torch.cat(new_masks, dim=0) 221 boxes = batched_mask_to_box(masks) 222 keep_by_nms = batched_nms( 223 boxes.float(), 224 torch.as_tensor(scores, dtype=torch.float), 225 torch.zeros_like(boxes[:, 0]), # categories 226 iou_threshold=nms_thresh, 227 ) 228 229 # only recalculate RLEs for masks that have changed 230 for i_mask in keep_by_nms: 231 if scores[i_mask] == 0.0: 232 mask_torch = masks[i_mask].unsqueeze(0) 233 # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] 234 mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 235 mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 236 mask_data.filter(keep_by_nms) 237 238 return mask_data 239 240 def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode): 241 # filter small disconnected regions and holes in masks 242 if min_mask_region_area > 0: 243 mask_data = self._postprocess_small_regions( 244 mask_data, 245 min_mask_region_area, 246 max(box_nms_thresh, crop_nms_thresh), 247 ) 248 249 # encode masks 250 if output_mode == "coco_rle": 251 mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]] 252 elif output_mode == "binary_mask": 253 mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]] 254 else: 255 mask_data["segmentations"] = mask_data["rles"] 256 257 # write mask records 258 curr_anns = [] 259 for idx in range(len(mask_data["segmentations"])): 260 ann = { 261 "segmentation": mask_data["segmentations"][idx], 262 "area": amg_utils.area_from_rle(mask_data["rles"][idx]), 263 "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 264 "predicted_iou": mask_data["iou_preds"][idx].item(), 265 "stability_score": mask_data["stability_score"][idx].item(), 266 "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 267 } 268 # the data from embedding based segmentation doesn't have the points 269 # so we skip if the corresponding key can't be found 270 try: 271 
ann["point_coords"] = [mask_data["points"][idx].tolist()] 272 except KeyError: 273 pass 274 curr_anns.append(ann) 275 276 return curr_anns 277 278 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): 279 orig_h, orig_w = original_size 280 281 # serialize predictions and store in MaskData 282 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1)) 283 if points is not None: 284 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float) 285 286 del masks 287 288 # calculate the stability scores 289 data["stability_score"] = amg_utils.calculate_stability_score( 290 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset 291 ) 292 293 # threshold masks and calculate boxes 294 data["masks"] = data["masks"] > self._predictor.model.mask_threshold 295 data["masks"] = data["masks"].type(torch.bool) 296 data["boxes"] = batched_mask_to_box(data["masks"]) 297 298 # compress to RLE 299 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 300 # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) 301 data["rles"] = mask_to_rle_pytorch(data["masks"]) 302 del data["masks"] 303 304 return data 305 306 def get_state(self) -> Dict[str, Any]: 307 """Get the initialized state of the mask generator. 308 309 Returns: 310 State of the mask generator. 311 """ 312 if not self.is_initialized: 313 raise RuntimeError("The state has not been computed yet. Call initialize first.") 314 315 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size} 316 317 def set_state(self, state: Dict[str, Any]) -> None: 318 """Set the state of the mask generator. 319 320 Args: 321 state: The state of the mask generator, e.g. from serialized state. 322 """ 323 self._crop_list = state["crop_list"] 324 self._crop_boxes = state["crop_boxes"] 325 self._original_size = state["original_size"] 326 self._is_initialized = True 327 328 def clear_state(self): 329 """Clear the state of the mask generator. 330 """ 331 self._crop_list = None 332 self._crop_boxes = None 333 self._original_size = None 334 self._is_initialized = False 335 336 337class AutomaticMaskGenerator(AMGBase): 338 """Generates an instance segmentation without prompts, using a point grid. 339 340 This class implements the same logic as 341 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 342 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 343 to filter these masks to enable grid search and interactively changing the post-processing. 344 345 Use this class as follows: 346 ```python 347 amg = AutomaticMaskGenerator(predictor) 348 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 349 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 350 ``` 351 352 Args: 353 predictor: The segment anything predictor. 354 points_per_side: The number of points to be sampled along one side of the image. 355 If None, `point_grids` must provide explicit point sampling. 356 points_per_batch: The number of points run simultaneously by the model. 357 Higher numbers may be faster but use more GPU memory. 358 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 359 crop_overlap_ratio: Sets the degree to which crops overlap. 
360 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 361 point_grids: A lisst over explicit grids of points used for sampling masks. 362 Normalized to [0, 1] with respect to the image coordinate system. 363 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 364 """ 365 def __init__( 366 self, 367 predictor: SamPredictor, 368 points_per_side: Optional[int] = 32, 369 points_per_batch: Optional[int] = None, 370 crop_n_layers: int = 0, 371 crop_overlap_ratio: float = 512 / 1500, 372 crop_n_points_downscale_factor: int = 1, 373 point_grids: Optional[List[np.ndarray]] = None, 374 stability_score_offset: float = 1.0, 375 ): 376 super().__init__() 377 378 if points_per_side is not None: 379 self.point_grids = amg_utils.build_all_layer_point_grids( 380 points_per_side, 381 crop_n_layers, 382 crop_n_points_downscale_factor, 383 ) 384 elif point_grids is not None: 385 self.point_grids = point_grids 386 else: 387 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 388 389 self._predictor = predictor 390 self._points_per_side = points_per_side 391 392 # we set the points per batch to 16 for mps for performance reasons 393 # and otherwise keep them at the default of 64 394 if points_per_batch is None: 395 points_per_batch = 16 if str(predictor.device) == "mps" else 64 396 self._points_per_batch = points_per_batch 397 398 self._crop_n_layers = crop_n_layers 399 self._crop_overlap_ratio = crop_overlap_ratio 400 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 401 self._stability_score_offset = stability_score_offset 402 403 def _process_batch(self, points, im_size, crop_box, original_size): 404 # run model on this batch 405 transformed_points = self._predictor.transform.apply_coords(points, im_size) 406 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 407 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 408 masks, iou_preds, _ = self._predictor.predict_torch( 409 in_points[:, None, :], 410 in_labels[:, None], 411 multimask_output=True, 412 return_logits=True, 413 ) 414 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 415 del masks 416 return data 417 418 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 419 # Crop the image and calculate embeddings. 420 x0, y0, x1, y1 = crop_box 421 cropped_im = image[y0:y1, x0:x1, :] 422 cropped_im_size = cropped_im.shape[:2] 423 424 if not precomputed_embeddings: 425 self._predictor.set_image(cropped_im) 426 427 # Get the points for this crop. 428 points_scale = np.array(cropped_im_size)[None, ::-1] 429 points_for_image = self.point_grids[crop_layer_idx] * points_scale 430 431 # Generate masks for this crop in batches. 
432 data = amg_utils.MaskData() 433 n_batches = len(points_for_image) // self._points_per_batch +\ 434 int(len(points_for_image) % self._points_per_batch != 0) 435 if pbar_init is not None: 436 pbar_init(n_batches, "Predict masks for point grid prompts") 437 438 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 439 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 440 data.cat(batch_data) 441 del batch_data 442 if pbar_update is not None: 443 pbar_update(1) 444 445 if not precomputed_embeddings: 446 self._predictor.reset_image() 447 448 return data 449 450 @torch.no_grad() 451 def initialize( 452 self, 453 image: np.ndarray, 454 image_embeddings: Optional[util.ImageEmbeddings] = None, 455 i: Optional[int] = None, 456 verbose: bool = False, 457 pbar_init: Optional[callable] = None, 458 pbar_update: Optional[callable] = None, 459 ) -> None: 460 """Initialize image embeddings and masks for an image. 461 462 Args: 463 image: The input image, volume or timeseries. 464 image_embeddings: Optional precomputed image embeddings. 465 See `util.precompute_image_embeddings` for details. 466 i: Index for the image data. Required if `image` has three spatial dimensions 467 or a time dimension and two spatial dimensions. 468 verbose: Whether to print computation progress. 469 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 470 Can be used together with pbar_update to handle napari progress bar in other thread. 471 To enables using this function within a threadworker. 472 pbar_update: Callback to update an external progress bar. 473 """ 474 original_size = image.shape[:2] 475 self._original_size = original_size 476 477 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 478 original_size, self._crop_n_layers, self._crop_overlap_ratio 479 ) 480 481 # We can set fixed image embeddings if we only have a single crop box (the default setting). 482 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 483 if len(crop_boxes) == 1: 484 if image_embeddings is None: 485 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 486 util.set_precomputed(self._predictor, image_embeddings, i=i) 487 precomputed_embeddings = True 488 else: 489 precomputed_embeddings = False 490 491 # we need to cast to the image representation that is compatible with SAM 492 image = util._to_image(image) 493 494 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 495 496 crop_list = [] 497 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 498 crop_data = self._process_crop( 499 image, crop_box, layer_idx, 500 precomputed_embeddings=precomputed_embeddings, 501 pbar_init=pbar_init, pbar_update=pbar_update, 502 ) 503 crop_list.append(crop_data) 504 pbar_close() 505 506 self._is_initialized = True 507 self._crop_list = crop_list 508 self._crop_boxes = crop_boxes 509 510 @torch.no_grad() 511 def generate( 512 self, 513 pred_iou_thresh: float = 0.88, 514 stability_score_thresh: float = 0.95, 515 box_nms_thresh: float = 0.7, 516 crop_nms_thresh: float = 0.7, 517 min_mask_region_area: int = 0, 518 output_mode: str = "binary_mask", 519 ) -> List[Dict[str, Any]]: 520 """Generate instance segmentation for the currently initialized image. 521 522 Args: 523 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 
524 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 525 under changes to the cutoff used to binarize the model prediction. 526 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 527 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 528 min_mask_region_area: Minimal size for the predicted masks. 529 output_mode: The form masks are returned in. 530 531 Returns: 532 The instance segmentation masks. 533 """ 534 if not self.is_initialized: 535 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 536 537 data = amg_utils.MaskData() 538 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 539 crop_data = self._postprocess_batch( 540 data=deepcopy(data_), 541 crop_box=crop_box, original_size=self.original_size, 542 pred_iou_thresh=pred_iou_thresh, 543 stability_score_thresh=stability_score_thresh, 544 box_nms_thresh=box_nms_thresh 545 ) 546 data.cat(crop_data) 547 548 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 549 # Prefer masks from smaller crops 550 scores = 1 / box_area(data["crop_boxes"]) 551 scores = scores.to(data["boxes"].device) 552 keep_by_nms = batched_nms( 553 data["boxes"].float(), 554 scores, 555 torch.zeros_like(data["boxes"][:, 0]), # categories 556 iou_threshold=crop_nms_thresh, 557 ) 558 data.filter(keep_by_nms) 559 560 data.to_numpy() 561 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 562 return masks 563 564 565# Helper function for tiled embedding computation and checking consistent state. 566def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo): 567 if image_embeddings is None: 568 if tile_shape is None or halo is None: 569 raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.") 570 image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo) 571 572 # Use tile shape and halo from the precomputed embeddings if not given. 573 # Otherwise check that they are consistent. 574 feats = image_embeddings["features"] 575 tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"] 576 if tile_shape is None: 577 tile_shape = tile_shape_ 578 elif tile_shape != tile_shape_: 579 raise ValueError( 580 f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}." 581 ) 582 if halo is None: 583 halo = halo_ 584 elif halo != halo_: 585 raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.") 586 587 return image_embeddings, tile_shape, halo 588 589 590class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 591 """Generates an instance segmentation without prompts, using a point grid. 592 593 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 594 595 Args: 596 predictor: The segment anything predictor. 597 points_per_side: The number of points to be sampled along one side of the image. 598 If None, `point_grids` must provide explicit point sampling. 599 points_per_batch: The number of points run simultaneously by the model. 600 Higher numbers may be faster but use more GPU memory. 601 point_grids: A lisst over explicit grids of points used for sampling masks. 602 Normalized to [0, 1] with respect to the image coordinate system. 
603 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 604 """ 605 606 # We only expose the arguments that make sense for the tiled mask generator. 607 # Anything related to crops doesn't make sense, because we re-use that functionality 608 # for tiling, so these parameters wouldn't have any effect. 609 def __init__( 610 self, 611 predictor: SamPredictor, 612 points_per_side: Optional[int] = 32, 613 points_per_batch: int = 64, 614 point_grids: Optional[List[np.ndarray]] = None, 615 stability_score_offset: float = 1.0, 616 ) -> None: 617 super().__init__( 618 predictor=predictor, 619 points_per_side=points_per_side, 620 points_per_batch=points_per_batch, 621 point_grids=point_grids, 622 stability_score_offset=stability_score_offset, 623 ) 624 625 @torch.no_grad() 626 def initialize( 627 self, 628 image: np.ndarray, 629 image_embeddings: Optional[util.ImageEmbeddings] = None, 630 i: Optional[int] = None, 631 tile_shape: Optional[Tuple[int, int]] = None, 632 halo: Optional[Tuple[int, int]] = None, 633 verbose: bool = False, 634 pbar_init: Optional[callable] = None, 635 pbar_update: Optional[callable] = None, 636 ) -> None: 637 """Initialize image embeddings and masks for an image. 638 639 Args: 640 image: The input image, volume or timeseries. 641 image_embeddings: Optional precomputed image embeddings. 642 See `util.precompute_image_embeddings` for details. 643 i: Index for the image data. Required if `image` has three spatial dimensions 644 or a time dimension and two spatial dimensions. 645 tile_shape: The tile shape for embedding prediction. 646 halo: The overlap of between tiles. 647 verbose: Whether to print computation progress. 648 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 649 Can be used together with pbar_update to handle napari progress bar in other thread. 650 To enables using this function within a threadworker. 651 pbar_update: Callback to update an external progress bar. 652 """ 653 original_size = image.shape[:2] 654 self._original_size = original_size 655 656 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 657 self._predictor, image, image_embeddings, tile_shape, halo 658 ) 659 660 tiling = blocking([0, 0], original_size, tile_shape) 661 n_tiles = tiling.numberOfBlocks 662 663 # The crop box is always the full local tile. 664 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 665 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 666 667 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 668 pbar_init(n_tiles, "Compute masks for tile") 669 670 # We need to cast to the image representation that is compatible with SAM. 
671 image = util._to_image(image) 672 673 mask_data = [] 674 for tile_id in range(n_tiles): 675 # set the pre-computed embeddings for this tile 676 features = image_embeddings["features"][tile_id] 677 tile_embeddings = { 678 "features": features, 679 "input_size": features.attrs["input_size"], 680 "original_size": features.attrs["original_size"], 681 } 682 util.set_precomputed(self._predictor, tile_embeddings, i) 683 684 # compute the mask data for this tile and append it 685 this_mask_data = self._process_crop( 686 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 687 ) 688 mask_data.append(this_mask_data) 689 pbar_update(1) 690 pbar_close() 691 692 # set the initialized data 693 self._is_initialized = True 694 self._crop_list = mask_data 695 self._crop_boxes = crop_boxes 696 697 698# 699# Instance segmentation functionality based on fine-tuned decoder 700# 701 702 703class DecoderAdapter(torch.nn.Module): 704 """Adapter to contain the UNETR decoder in a single module. 705 706 To apply the decoder on top of pre-computed embeddings for 707 the segmentation functionality. 708 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 709 """ 710 def __init__(self, unetr): 711 super().__init__() 712 713 self.base = unetr.base 714 self.out_conv = unetr.out_conv 715 self.deconv_out = unetr.deconv_out 716 self.decoder_head = unetr.decoder_head 717 self.final_activation = unetr.final_activation 718 self.postprocess_masks = unetr.postprocess_masks 719 720 self.decoder = unetr.decoder 721 self.deconv1 = unetr.deconv1 722 self.deconv2 = unetr.deconv2 723 self.deconv3 = unetr.deconv3 724 self.deconv4 = unetr.deconv4 725 726 def forward(self, input_, input_shape, original_shape): 727 z12 = input_ 728 729 z9 = self.deconv1(z12) 730 z6 = self.deconv2(z9) 731 z3 = self.deconv3(z6) 732 z0 = self.deconv4(z3) 733 734 updated_from_encoder = [z9, z6, z3] 735 736 x = self.base(z12) 737 x = self.decoder(x, encoder_inputs=updated_from_encoder) 738 x = self.deconv_out(x) 739 740 x = torch.cat([x, z0], dim=1) 741 x = self.decoder_head(x) 742 743 x = self.out_conv(x) 744 if self.final_activation is not None: 745 x = self.final_activation(x) 746 747 x = self.postprocess_masks(x, input_shape, original_shape) 748 return x 749 750 751def get_unetr( 752 image_encoder: torch.nn.Module, 753 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 754 device: Optional[Union[str, torch.device]] = None, 755) -> torch.nn.Module: 756 """Get UNETR model for automatic instance segmentation. 757 758 Args: 759 image_encoder: The image encoder of the SAM model. 760 This is used as encoder by the UNETR too. 761 decoder_state: Optional decoder state to initialize the weights 762 of the UNETR decoder. 763 device: The device. 764 Returns: 765 The UNETR model. 
766 """ 767 device = util.get_device(device) 768 769 unetr = UNETR( 770 backbone="sam", 771 encoder=image_encoder, 772 out_channels=3, 773 use_sam_stats=True, 774 final_activation="Sigmoid", 775 use_skip_connection=False, 776 resize_input=True, 777 ) 778 if decoder_state is not None: 779 unetr_state_dict = unetr.state_dict() 780 for k, v in unetr_state_dict.items(): 781 if not k.startswith("encoder"): 782 unetr_state_dict[k] = decoder_state[k] 783 unetr.load_state_dict(unetr_state_dict) 784 785 unetr.to(device) 786 return unetr 787 788 789def get_decoder( 790 image_encoder: torch.nn.Module, 791 decoder_state: OrderedDict[str, torch.Tensor], 792 device: Optional[Union[str, torch.device]] = None, 793) -> DecoderAdapter: 794 """Get decoder to predict outputs for automatic instance segmentation 795 796 Args: 797 image_encoder: The image encoder of the SAM model. 798 decoder_state: State to initialize the weights of the UNETR decoder. 799 device: The device. 800 Returns: 801 The decoder for instance segmentation. 802 """ 803 unetr = get_unetr(image_encoder, decoder_state, device) 804 return DecoderAdapter(unetr) 805 806 807def get_predictor_and_decoder( 808 model_type: str, 809 checkpoint_path: Union[str, os.PathLike], 810 device: Optional[Union[str, torch.device]] = None, 811 peft_kwargs: Optional[Dict] = None, 812) -> Tuple[SamPredictor, DecoderAdapter]: 813 """Load the SAM model (predictor) and instance segmentation decoder. 814 815 This requires a checkpoint that contains the state for both predictor 816 and decoder. 817 818 Args: 819 model_type: The type of the image encoder used in the SAM model. 820 checkpoint_path: Path to the checkpoint from which to load the data. 821 device: The device. 822 lora_rank: The rank for low rank adaptation of the attention layers. 823 824 Returns: 825 The SAM predictor. 826 The decoder for instance segmentation. 827 """ 828 device = util.get_device(device) 829 predictor, state = util.get_sam_model( 830 model_type=model_type, 831 checkpoint_path=checkpoint_path, 832 device=device, 833 return_state=True, 834 peft_kwargs=peft_kwargs, 835 ) 836 if "decoder_state" not in state: 837 raise ValueError( 838 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 839 ) 840 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 841 return predictor, decoder 842 843 844class InstanceSegmentationWithDecoder: 845 """Generates an instance segmentation without prompts, using a decoder. 846 847 Implements the same interface as `AutomaticMaskGenerator`. 848 849 Use this class as follows: 850 ```python 851 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 852 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 853 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 854 ``` 855 856 Args: 857 predictor: The segment anything predictor. 858 decoder: The decoder to predict intermediate representations 859 for instance segmentation. 860 """ 861 def __init__( 862 self, 863 predictor: SamPredictor, 864 decoder: torch.nn.Module, 865 ) -> None: 866 self._predictor = predictor 867 self._decoder = decoder 868 869 # The decoder outputs. 870 self._foreground = None 871 self._center_distances = None 872 self._boundary_distances = None 873 874 self._is_initialized = False 875 876 @property 877 def is_initialized(self): 878 """Whether the mask generator has already been initialized. 
879 """ 880 return self._is_initialized 881 882 @torch.no_grad() 883 def initialize( 884 self, 885 image: np.ndarray, 886 image_embeddings: Optional[util.ImageEmbeddings] = None, 887 i: Optional[int] = None, 888 verbose: bool = False, 889 pbar_init: Optional[callable] = None, 890 pbar_update: Optional[callable] = None, 891 ) -> None: 892 """Initialize image embeddings and decoder predictions for an image. 893 894 Args: 895 image: The input image, volume or timeseries. 896 image_embeddings: Optional precomputed image embeddings. 897 See `util.precompute_image_embeddings` for details. 898 i: Index for the image data. Required if `image` has three spatial dimensions 899 or a time dimension and two spatial dimensions. 900 verbose: Whether to be verbose. 901 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 902 Can be used together with pbar_update to handle napari progress bar in other thread. 903 To enables using this function within a threadworker. 904 pbar_update: Callback to update an external progress bar. 905 """ 906 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 907 pbar_init(1, "Initialize instance segmentation with decoder") 908 909 if image_embeddings is None: 910 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 911 912 # Get the image embeddings from the predictor. 913 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 914 embeddings = self._predictor.features 915 input_shape = tuple(self._predictor.input_size) 916 original_shape = tuple(self._predictor.original_size) 917 918 # Run prediction with the UNETR decoder. 919 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 920 assert output.shape[0] == 3, f"{output.shape}" 921 pbar_update(1) 922 pbar_close() 923 924 # Set the state. 925 self._foreground = output[0] 926 self._center_distances = output[1] 927 self._boundary_distances = output[2] 928 self._is_initialized = True 929 930 def _to_masks(self, segmentation, output_mode): 931 if output_mode != "binary_mask": 932 raise NotImplementedError 933 934 props = regionprops(segmentation) 935 ndim = segmentation.ndim 936 assert ndim in (2, 3) 937 938 shape = segmentation.shape 939 if ndim == 2: 940 crop_box = [0, shape[1], 0, shape[0]] 941 else: 942 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 943 944 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 945 def to_bbox_2d(bbox): 946 y0, x0 = bbox[0], bbox[1] 947 w = bbox[3] - x0 948 h = bbox[2] - y0 949 return [x0, w, y0, h] 950 951 def to_bbox_3d(bbox): 952 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 953 w = bbox[5] - x0 954 h = bbox[4] - y0 955 d = bbox[3] - y0 956 return [x0, w, y0, h, z0, d] 957 958 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 959 masks = [ 960 { 961 "segmentation": segmentation == prop.label, 962 "area": prop.area, 963 "bbox": to_bbox(prop.bbox), 964 "crop_box": crop_box, 965 "seg_id": prop.label, 966 } for prop in props 967 ] 968 return masks 969 970 def generate( 971 self, 972 center_distance_threshold: float = 0.5, 973 boundary_distance_threshold: float = 0.5, 974 foreground_threshold: float = 0.5, 975 foreground_smoothing: float = 1.0, 976 distance_smoothing: float = 1.6, 977 min_size: int = 0, 978 output_mode: Optional[str] = "binary_mask", 979 ) -> List[Dict[str, Any]]: 980 """Generate instance segmentation for the currently initialized image. 
981 982 Args: 983 center_distance_threshold: Center distance predictions below this value will be 984 used to find seeds (intersected with thresholded boundary distance predictions). 985 boundary_distance_threshold: Boundary distance predictions below this value will be 986 used to find seeds (intersected with thresholded center distance predictions). 987 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 988 checkerboard artifacts in the prediction. 989 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 990 distance_smoothing: Sigma value for smoothing the distance predictions. 991 min_size: Minimal object size in the segmentation result. 992 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 993 994 Returns: 995 The instance segmentation masks. 996 """ 997 if not self.is_initialized: 998 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 999 1000 if foreground_smoothing > 0: 1001 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1002 else: 1003 foreground = self._foreground 1004 # Further optimization: parallel implementation using elf.parallel functionality. 1005 # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) 1006 segmentation = watershed_from_center_and_boundary_distances( 1007 self._center_distances, self._boundary_distances, foreground, 1008 center_distance_threshold=center_distance_threshold, 1009 boundary_distance_threshold=boundary_distance_threshold, 1010 foreground_threshold=foreground_threshold, 1011 distance_smoothing=distance_smoothing, 1012 min_size=min_size, 1013 ) 1014 if output_mode is not None: 1015 segmentation = self._to_masks(segmentation, output_mode) 1016 return segmentation 1017 1018 def get_state(self) -> Dict[str, Any]: 1019 """Get the initialized state of the instance segmenter. 1020 1021 Returns: 1022 Instance segmentation state. 1023 """ 1024 if not self.is_initialized: 1025 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1026 1027 return { 1028 "foreground": self._foreground, 1029 "center_distances": self._center_distances, 1030 "boundary_distances": self._boundary_distances, 1031 } 1032 1033 def set_state(self, state: Dict[str, Any]) -> None: 1034 """Set the state of the instance segmenter. 1035 1036 Args: 1037 state: The instance segmentation state 1038 """ 1039 self._foreground = state["foreground"] 1040 self._center_distances = state["center_distances"] 1041 self._boundary_distances = state["boundary_distances"] 1042 self._is_initialized = True 1043 1044 def clear_state(self): 1045 """Clear the state of the instance segmenter. 1046 """ 1047 self._foreground = None 1048 self._center_distances = None 1049 self._boundary_distances = None 1050 self._is_initialized = False 1051 1052 1053class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1054 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 
1055 """ 1056 1057 @torch.no_grad() 1058 def initialize( 1059 self, 1060 image: np.ndarray, 1061 image_embeddings: Optional[util.ImageEmbeddings] = None, 1062 i: Optional[int] = None, 1063 tile_shape: Optional[Tuple[int, int]] = None, 1064 halo: Optional[Tuple[int, int]] = None, 1065 verbose: bool = False, 1066 pbar_init: Optional[callable] = None, 1067 pbar_update: Optional[callable] = None, 1068 ) -> None: 1069 """Initialize image embeddings and decoder predictions for an image. 1070 1071 Args: 1072 image: The input image, volume or timeseries. 1073 image_embeddings: Optional precomputed image embeddings. 1074 See `util.precompute_image_embeddings` for details. 1075 i: Index for the image data. Required if `image` has three spatial dimensions 1076 or a time dimension and two spatial dimensions. 1077 verbose: Dummy input to be compatible with other function signatures. 1078 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1079 Can be used together with pbar_update to handle napari progress bar in other thread. 1080 To enables using this function within a threadworker. 1081 pbar_update: Callback to update an external progress bar. 1082 """ 1083 original_size = image.shape[:2] 1084 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1085 self._predictor, image, image_embeddings, tile_shape, halo 1086 ) 1087 tiling = blocking([0, 0], original_size, tile_shape) 1088 1089 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1090 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1091 1092 foreground = np.zeros(original_size, dtype="float32") 1093 center_distances = np.zeros(original_size, dtype="float32") 1094 boundary_distances = np.zeros(original_size, dtype="float32") 1095 1096 for tile_id in range(tiling.numberOfBlocks): 1097 1098 # Get the image embeddings from the predictor for this tile. 1099 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1100 embeddings = self._predictor.features 1101 input_shape = tuple(self._predictor.input_size) 1102 original_shape = tuple(self._predictor.original_size) 1103 1104 # Predict with the UNETR decoder for this tile. 1105 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1106 assert output.shape[0] == 3, f"{output.shape}" 1107 1108 # Set the predictions in the output for this tile. 1109 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1110 local_bb = tuple( 1111 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1112 ) 1113 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1114 1115 foreground[inner_bb] = output[0][local_bb] 1116 center_distances[inner_bb] = output[1][local_bb] 1117 boundary_distances[inner_bb] = output[2][local_bb] 1118 pbar_close() 1119 1120 # Set the state. 1121 self._foreground = foreground 1122 self._center_distances = center_distances 1123 self._boundary_distances = boundary_distances 1124 self._is_initialized = True 1125 1126 1127def get_amg( 1128 predictor: SamPredictor, 1129 is_tiled: bool, 1130 decoder: Optional[torch.nn.Module] = None, 1131 **kwargs, 1132) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1133 """Get the automatic mask generator class. 1134 1135 Args: 1136 predictor: The segment anything predictor. 1137 is_tiled: Whether tiled embeddings are used. 
1138 decoder: Decoder to predict instacne segmmentation. 1139 kwargs: The keyword arguments for the amg class. 1140 1141 Returns: 1142 The automatic mask generator. 1143 """ 1144 if decoder is None: 1145 segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\ 1146 AutomaticMaskGenerator(predictor, **kwargs) 1147 else: 1148 segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\ 1149 InstanceSegmentationWithDecoder(predictor, decoder, **kwargs) 1150 return segmenter
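The decoder-based segmenters defined at the end of the listing (`get_predictor_and_decoder`, `InstanceSegmentationWithDecoder`) follow the same initialize/generate pattern as the mask generators. A minimal sketch, assuming a fine-tuned checkpoint that bundles a decoder state; the path and model type are placeholders:

```python
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder, mask_data_to_segmentation
)

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b", checkpoint_path="path/to/checkpoint.pt"  # placeholder checkpoint
)

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # image: a 2d numpy array loaded beforehand
masks = segmenter.generate(center_distance_threshold=0.5, boundary_distance_threshold=0.5)
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```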
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    with_background: bool,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
            Only supports output_mode=binary_mask.
        with_background: Whether the segmentation has background. If yes this function assures that the largest
            object in the output will be mapped to zero (the background value).
        min_object_size: The minimal size of an object in pixels.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.

    Returns:
        The instance segmentation.
    """

    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    # we could also get the shape from the crop box
    shape = next(iter(masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask in masks:
        if mask["area"] < min_object_size:
            continue
        if max_object_size is not None and mask["area"] > max_object_size:
            continue

        this_seg_id = mask.get("seg_id", seg_id)
        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
        seg_id = this_seg_id + 1

    if label_masks:
        segmentation = label(segmentation)
    seg_ids, sizes = np.unique(segmentation, return_counts=True)

    # In some cases objects may be smaller than previously calculated,
    # since they are covered by other objects. We ensure these also get
    # filtered out here.
    filter_ids = seg_ids[sizes < min_object_size]

    # If we run segmentation with background we also map the largest segment
    # (the most likely background object) to zero. This is often zero already,
    # but it does not hurt to reset that to zero either.
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    segmentation[np.isin(segmentation, filter_ids)] = 0
    segmentation = relabel_sequential(segmentation)[0]

    return segmentation
Convert the output of the automatic mask generation to an instance segmentation.
Args:
masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
Only supports output_mode=binary_mask.
with_background: Whether the segmentation has background. If yes this function assures that the largest
object in the output will be mapped to zero (the background value).
min_object_size: The minimal size of an object in pixels.
max_object_size: The maximal size of an object in pixels.
label_masks: Whether to apply connected components to the result before removing small objects.
Returns:
The instance segmentation.
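For example, to turn the generator output into a label image while filtering out very small and very large objects (the size values below are arbitrary):

```python
# masks: output of AutomaticMaskGenerator.generate with output_mode="binary_mask"
segmentation = mask_data_to_segmentation(
    masks,
    with_background=True,     # map the largest segment (most likely background) to 0
    min_object_size=100,      # drop objects smaller than 100 pixels
    max_object_size=100_000,  # drop objects larger than 100000 pixels
    label_masks=True,         # apply connected components before the size filter
)
```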
class AMGBase(ABC):
Base class for the automatic mask generators.
@property
def is_initialized(self):
    """Whether the mask generator has already been initialized.
    """
    return self._is_initialized
Whether the mask generator has already been initialized.
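For example, this property can guard against re-running the expensive initialization (continuing the hypothetical `amg` and `image` from the sketch near the top of the page):

```python
if not amg.is_initialized:
    amg.initialize(image)
masks = amg.generate()
```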
@property
def crop_list(self):
    """The list of mask data after initialization.
    """
    return self._crop_list
The list of mask data after initialization.
@property
def crop_boxes(self):
    """The list of crop boxes.
    """
    return self._crop_boxes
The list of crop boxes.
@property
def original_size(self):
    """The original image size.
    """
    return self._original_size
The original image size.
def get_state(self) -> Dict[str, Any]:
    """Get the initialized state of the mask generator.

    Returns:
        State of the mask generator.
    """
    if not self.is_initialized:
        raise RuntimeError("The state has not been computed yet. Call initialize first.")

    return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
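Since `initialize` is the expensive step, the returned state can be cached to disk, e.g. with pickle; whether every entry serializes cleanly depends on the stored mask data, so treat this as a sketch (again reusing the hypothetical `amg` and `image`):

```python
import pickle

amg.initialize(image)
with open("amg_state.pkl", "wb") as f:  # hypothetical cache file
    pickle.dump(amg.get_state(), f)
```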
def set_state(self, state: Dict[str, Any]) -> None:
    """Set the state of the mask generator.

    Args:
        state: The state of the mask generator, e.g. from serialized state.
    """
    self._crop_list = state["crop_list"]
    self._crop_boxes = state["crop_boxes"]
    self._original_size = state["original_size"]
    self._is_initialized = True
Set the state of the mask generator.
Args:
state: The state of the mask generator, e.g. from serialized state.
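A previously cached state can then be restored on a fresh generator, so that `generate` can run without calling `initialize` again (continuing the pickle sketch from `get_state` above; `predictor` is the one from the first sketch):

```python
amg = AutomaticMaskGenerator(predictor)
with open("amg_state.pkl", "rb") as f:
    amg.set_state(pickle.load(f))
masks = amg.generate(pred_iou_thresh=0.9)
```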
class AutomaticMaskGenerator(AMGBase):
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive steps of generating masks from the cheap post-processing operations that filter these masks, to enable grid search and interactively changing the post-processing.
Use this class as follows:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image) # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
- crop_overlap_ratio: Sets the degree to which crops overlap.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
- point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
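To make the argument list concrete, here is a minimal end-to-end sketch. The model type and the dummy image are assumptions, `util.get_sam_model` is assumed to return a SamPredictor by default, and `mask_data_to_segmentation` from this module is used to turn the mask list into a label image:
```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Load a SAM predictor; "vit_b" is just an example model type.
predictor = util.get_sam_model(model_type="vit_b")

# Placeholder image; in practice this would be a 2d microscopy image.
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")

# A denser point grid finds more (and smaller) objects, at a higher runtime cost.
amg = AutomaticMaskGenerator(predictor, points_per_side=32)
amg.initialize(image)  # expensive: embeddings and mask prediction for all grid prompts
masks = amg.generate(pred_iou_thresh=0.88)  # cheap: only filtering / post-processing

# Convert the list of mask dictionaries into an instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```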
366 def __init__( 367 self, 368 predictor: SamPredictor, 369 points_per_side: Optional[int] = 32, 370 points_per_batch: Optional[int] = None, 371 crop_n_layers: int = 0, 372 crop_overlap_ratio: float = 512 / 1500, 373 crop_n_points_downscale_factor: int = 1, 374 point_grids: Optional[List[np.ndarray]] = None, 375 stability_score_offset: float = 1.0, 376 ): 377 super().__init__() 378 379 if points_per_side is not None: 380 self.point_grids = amg_utils.build_all_layer_point_grids( 381 points_per_side, 382 crop_n_layers, 383 crop_n_points_downscale_factor, 384 ) 385 elif point_grids is not None: 386 self.point_grids = point_grids 387 else: 388 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 389 390 self._predictor = predictor 391 self._points_per_side = points_per_side 392 393 # we set the points per batch to 16 for mps for performance reasons 394 # and otherwise keep them at the default of 64 395 if points_per_batch is None: 396 points_per_batch = 16 if str(predictor.device) == "mps" else 64 397 self._points_per_batch = points_per_batch 398 399 self._crop_n_layers = crop_n_layers 400 self._crop_overlap_ratio = crop_overlap_ratio 401 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 402 self._stability_score_offset = stability_score_offset
451 @torch.no_grad() 452 def initialize( 453 self, 454 image: np.ndarray, 455 image_embeddings: Optional[util.ImageEmbeddings] = None, 456 i: Optional[int] = None, 457 verbose: bool = False, 458 pbar_init: Optional[callable] = None, 459 pbar_update: Optional[callable] = None, 460 ) -> None: 461 """Initialize image embeddings and masks for an image. 462 463 Args: 464 image: The input image, volume or timeseries. 465 image_embeddings: Optional precomputed image embeddings. 466 See `util.precompute_image_embeddings` for details. 467 i: Index for the image data. Required if `image` has three spatial dimensions 468 or a time dimension and two spatial dimensions. 469 verbose: Whether to print computation progress. 470 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 471 Can be used together with pbar_update to handle napari progress bar in other thread. 472 To enables using this function within a threadworker. 473 pbar_update: Callback to update an external progress bar. 474 """ 475 original_size = image.shape[:2] 476 self._original_size = original_size 477 478 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 479 original_size, self._crop_n_layers, self._crop_overlap_ratio 480 ) 481 482 # We can set fixed image embeddings if we only have a single crop box (the default setting). 483 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 484 if len(crop_boxes) == 1: 485 if image_embeddings is None: 486 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 487 util.set_precomputed(self._predictor, image_embeddings, i=i) 488 precomputed_embeddings = True 489 else: 490 precomputed_embeddings = False 491 492 # we need to cast to the image representation that is compatible with SAM 493 image = util._to_image(image) 494 495 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 496 497 crop_list = [] 498 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 499 crop_data = self._process_crop( 500 image, crop_box, layer_idx, 501 precomputed_embeddings=precomputed_embeddings, 502 pbar_init=pbar_init, pbar_update=pbar_update, 503 ) 504 crop_list.append(crop_data) 505 pbar_close() 506 507 self._is_initialized = True 508 self._crop_list = crop_list 509 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
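A sketch of reusing precomputed embeddings; the reuse pattern is the assumption here, and `predictor`, `image` and `amg` are as in the previous sketch:
```python
from micro_sam import util

# Compute the embeddings once and reuse them, e.g. across several parameter searches
# or between the AMG and the decoder-based segmenter.
image_embeddings = util.precompute_image_embeddings(predictor, image)

amg.initialize(image, image_embeddings=image_embeddings, verbose=True)
```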
511 @torch.no_grad() 512 def generate( 513 self, 514 pred_iou_thresh: float = 0.88, 515 stability_score_thresh: float = 0.95, 516 box_nms_thresh: float = 0.7, 517 crop_nms_thresh: float = 0.7, 518 min_mask_region_area: int = 0, 519 output_mode: str = "binary_mask", 520 ) -> List[Dict[str, Any]]: 521 """Generate instance segmentation for the currently initialized image. 522 523 Args: 524 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 525 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 526 under changes to the cutoff used to binarize the model prediction. 527 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 528 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 529 min_mask_region_area: Minimal size for the predicted masks. 530 output_mode: The form masks are returned in. 531 532 Returns: 533 The instance segmentation masks. 534 """ 535 if not self.is_initialized: 536 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 537 538 data = amg_utils.MaskData() 539 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 540 crop_data = self._postprocess_batch( 541 data=deepcopy(data_), 542 crop_box=crop_box, original_size=self.original_size, 543 pred_iou_thresh=pred_iou_thresh, 544 stability_score_thresh=stability_score_thresh, 545 box_nms_thresh=box_nms_thresh 546 ) 547 data.cat(crop_data) 548 549 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 550 # Prefer masks from smaller crops 551 scores = 1 / box_area(data["crop_boxes"]) 552 scores = scores.to(data["boxes"].device) 553 keep_by_nms = batched_nms( 554 data["boxes"].float(), 555 scores, 556 torch.zeros_like(data["boxes"][:, 0]), # categories 557 iou_threshold=crop_nms_thresh, 558 ) 559 data.filter(keep_by_nms) 560 561 data.to_numpy() 562 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 563 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
- min_mask_region_area: Minimal size for the predicted masks.
- output_mode: The form masks are returned in.
Returns:
The instance segmentation masks.
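Since `initialize` holds all expensive computation, different post-processing settings can be compared cheaply. A small illustrative parameter sweep (threshold values are arbitrary; `amg` is the initialized generator from above):
```python
# The masks are recomputed from the cached state of initialize(), so this loop is fast.
for iou_thresh in (0.80, 0.88, 0.92):
    masks = amg.generate(pred_iou_thresh=iou_thresh, stability_score_thresh=0.9)
    print(f"pred_iou_thresh={iou_thresh}: {len(masks)} masks")
```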
591class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 592 """Generates an instance segmentation without prompts, using a point grid. 593 594 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 595 596 Args: 597 predictor: The segment anything predictor. 598 points_per_side: The number of points to be sampled along one side of the image. 599 If None, `point_grids` must provide explicit point sampling. 600 points_per_batch: The number of points run simultaneously by the model. 601 Higher numbers may be faster but use more GPU memory. 602 point_grids: A lisst over explicit grids of points used for sampling masks. 603 Normalized to [0, 1] with respect to the image coordinate system. 604 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 605 """ 606 607 # We only expose the arguments that make sense for the tiled mask generator. 608 # Anything related to crops doesn't make sense, because we re-use that functionality 609 # for tiling, so these parameters wouldn't have any effect. 610 def __init__( 611 self, 612 predictor: SamPredictor, 613 points_per_side: Optional[int] = 32, 614 points_per_batch: int = 64, 615 point_grids: Optional[List[np.ndarray]] = None, 616 stability_score_offset: float = 1.0, 617 ) -> None: 618 super().__init__( 619 predictor=predictor, 620 points_per_side=points_per_side, 621 points_per_batch=points_per_batch, 622 point_grids=point_grids, 623 stability_score_offset=stability_score_offset, 624 ) 625 626 @torch.no_grad() 627 def initialize( 628 self, 629 image: np.ndarray, 630 image_embeddings: Optional[util.ImageEmbeddings] = None, 631 i: Optional[int] = None, 632 tile_shape: Optional[Tuple[int, int]] = None, 633 halo: Optional[Tuple[int, int]] = None, 634 verbose: bool = False, 635 pbar_init: Optional[callable] = None, 636 pbar_update: Optional[callable] = None, 637 ) -> None: 638 """Initialize image embeddings and masks for an image. 639 640 Args: 641 image: The input image, volume or timeseries. 642 image_embeddings: Optional precomputed image embeddings. 643 See `util.precompute_image_embeddings` for details. 644 i: Index for the image data. Required if `image` has three spatial dimensions 645 or a time dimension and two spatial dimensions. 646 tile_shape: The tile shape for embedding prediction. 647 halo: The overlap of between tiles. 648 verbose: Whether to print computation progress. 649 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 650 Can be used together with pbar_update to handle napari progress bar in other thread. 651 To enables using this function within a threadworker. 652 pbar_update: Callback to update an external progress bar. 653 """ 654 original_size = image.shape[:2] 655 self._original_size = original_size 656 657 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 658 self._predictor, image, image_embeddings, tile_shape, halo 659 ) 660 661 tiling = blocking([0, 0], original_size, tile_shape) 662 n_tiles = tiling.numberOfBlocks 663 664 # The crop box is always the full local tile. 665 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 666 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 667 668 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 669 pbar_init(n_tiles, "Compute masks for tile") 670 671 # We need to cast to the image representation that is compatible with SAM. 
672 image = util._to_image(image) 673 674 mask_data = [] 675 for tile_id in range(n_tiles): 676 # set the pre-computed embeddings for this tile 677 features = image_embeddings["features"][tile_id] 678 tile_embeddings = { 679 "features": features, 680 "input_size": features.attrs["input_size"], 681 "original_size": features.attrs["original_size"], 682 } 683 util.set_precomputed(self._predictor, tile_embeddings, i) 684 685 # compute the mask data for this tile and append it 686 this_mask_data = self._process_crop( 687 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 688 ) 689 mask_data.append(this_mask_data) 690 pbar_update(1) 691 pbar_close() 692 693 # set the initialized data 694 self._is_initialized = True 695 self._crop_list = mask_data 696 self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
- point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
610 def __init__( 611 self, 612 predictor: SamPredictor, 613 points_per_side: Optional[int] = 32, 614 points_per_batch: int = 64, 615 point_grids: Optional[List[np.ndarray]] = None, 616 stability_score_offset: float = 1.0, 617 ) -> None: 618 super().__init__( 619 predictor=predictor, 620 points_per_side=points_per_side, 621 points_per_batch=points_per_batch, 622 point_grids=point_grids, 623 stability_score_offset=stability_score_offset, 624 )
626 @torch.no_grad() 627 def initialize( 628 self, 629 image: np.ndarray, 630 image_embeddings: Optional[util.ImageEmbeddings] = None, 631 i: Optional[int] = None, 632 tile_shape: Optional[Tuple[int, int]] = None, 633 halo: Optional[Tuple[int, int]] = None, 634 verbose: bool = False, 635 pbar_init: Optional[callable] = None, 636 pbar_update: Optional[callable] = None, 637 ) -> None: 638 """Initialize image embeddings and masks for an image. 639 640 Args: 641 image: The input image, volume or timeseries. 642 image_embeddings: Optional precomputed image embeddings. 643 See `util.precompute_image_embeddings` for details. 644 i: Index for the image data. Required if `image` has three spatial dimensions 645 or a time dimension and two spatial dimensions. 646 tile_shape: The tile shape for embedding prediction. 647 halo: The overlap of between tiles. 648 verbose: Whether to print computation progress. 649 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 650 Can be used together with pbar_update to handle napari progress bar in other thread. 651 To enables using this function within a threadworker. 652 pbar_update: Callback to update an external progress bar. 653 """ 654 original_size = image.shape[:2] 655 self._original_size = original_size 656 657 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 658 self._predictor, image, image_embeddings, tile_shape, halo 659 ) 660 661 tiling = blocking([0, 0], original_size, tile_shape) 662 n_tiles = tiling.numberOfBlocks 663 664 # The crop box is always the full local tile. 665 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 666 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 667 668 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 669 pbar_init(n_tiles, "Compute masks for tile") 670 671 # We need to cast to the image representation that is compatible with SAM. 672 image = util._to_image(image) 673 674 mask_data = [] 675 for tile_id in range(n_tiles): 676 # set the pre-computed embeddings for this tile 677 features = image_embeddings["features"][tile_id] 678 tile_embeddings = { 679 "features": features, 680 "input_size": features.attrs["input_size"], 681 "original_size": features.attrs["original_size"], 682 } 683 util.set_precomputed(self._predictor, tile_embeddings, i) 684 685 # compute the mask data for this tile and append it 686 this_mask_data = self._process_crop( 687 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 688 ) 689 mask_data.append(this_mask_data) 690 pbar_update(1) 691 pbar_close() 692 693 # set the initialized data 694 self._is_initialized = True 695 self._crop_list = mask_data 696 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
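A sketch of the tiled workflow; the tile_shape and halo values are illustrative and would be chosen based on image size and GPU memory (`predictor` as in the earlier sketches):
```python
import numpy as np
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

large_image = np.random.randint(0, 255, size=(2048, 2048), dtype="uint8")  # placeholder

tiled_amg = TiledAutomaticMaskGenerator(predictor)
# Note: tile_shape and halo are arguments of initialize, not of the constructor.
tiled_amg.initialize(large_image, tile_shape=(512, 512), halo=(64, 64))
masks = tiled_amg.generate(pred_iou_thresh=0.88)
segmentation = mask_data_to_segmentation(masks, with_background=True)
```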
704class DecoderAdapter(torch.nn.Module): 705 """Adapter to contain the UNETR decoder in a single module. 706 707 To apply the decoder on top of pre-computed embeddings for 708 the segmentation functionality. 709 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 710 """ 711 def __init__(self, unetr): 712 super().__init__() 713 714 self.base = unetr.base 715 self.out_conv = unetr.out_conv 716 self.deconv_out = unetr.deconv_out 717 self.decoder_head = unetr.decoder_head 718 self.final_activation = unetr.final_activation 719 self.postprocess_masks = unetr.postprocess_masks 720 721 self.decoder = unetr.decoder 722 self.deconv1 = unetr.deconv1 723 self.deconv2 = unetr.deconv2 724 self.deconv3 = unetr.deconv3 725 self.deconv4 = unetr.deconv4 726 727 def forward(self, input_, input_shape, original_shape): 728 z12 = input_ 729 730 z9 = self.deconv1(z12) 731 z6 = self.deconv2(z9) 732 z3 = self.deconv3(z6) 733 z0 = self.deconv4(z3) 734 735 updated_from_encoder = [z9, z6, z3] 736 737 x = self.base(z12) 738 x = self.decoder(x, encoder_inputs=updated_from_encoder) 739 x = self.deconv_out(x) 740 741 x = torch.cat([x, z0], dim=1) 742 x = self.decoder_head(x) 743 744 x = self.out_conv(x) 745 if self.final_activation is not None: 746 x = self.final_activation(x) 747 748 x = self.postprocess_masks(x, input_shape, original_shape) 749 return x
Adapter to contain the UNETR decoder in a single module.
It applies the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
711 def __init__(self, unetr): 712 super().__init__() 713 714 self.base = unetr.base 715 self.out_conv = unetr.out_conv 716 self.deconv_out = unetr.deconv_out 717 self.decoder_head = unetr.decoder_head 718 self.final_activation = unetr.final_activation 719 self.postprocess_masks = unetr.postprocess_masks 720 721 self.decoder = unetr.decoder 722 self.deconv1 = unetr.deconv1 723 self.deconv2 = unetr.deconv2 724 self.deconv3 = unetr.deconv3 725 self.deconv4 = unetr.deconv4
Initialize internal Module state, shared by both nn.Module and ScriptModule.
727 def forward(self, input_, input_shape, original_shape): 728 z12 = input_ 729 730 z9 = self.deconv1(z12) 731 z6 = self.deconv2(z9) 732 z3 = self.deconv3(z6) 733 z0 = self.deconv4(z3) 734 735 updated_from_encoder = [z9, z6, z3] 736 737 x = self.base(z12) 738 x = self.decoder(x, encoder_inputs=updated_from_encoder) 739 x = self.deconv_out(x) 740 741 x = torch.cat([x, z0], dim=1) 742 x = self.decoder_head(x) 743 744 x = self.out_conv(x) 745 if self.final_activation is not None: 746 x = self.final_activation(x) 747 748 x = self.postprocess_masks(x, input_shape, original_shape) 749 return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
752def get_unetr( 753 image_encoder: torch.nn.Module, 754 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 755 device: Optional[Union[str, torch.device]] = None, 756) -> torch.nn.Module: 757 """Get UNETR model for automatic instance segmentation. 758 759 Args: 760 image_encoder: The image encoder of the SAM model. 761 This is used as encoder by the UNETR too. 762 decoder_state: Optional decoder state to initialize the weights 763 of the UNETR decoder. 764 device: The device. 765 Returns: 766 The UNETR model. 767 """ 768 device = util.get_device(device) 769 770 unetr = UNETR( 771 backbone="sam", 772 encoder=image_encoder, 773 out_channels=3, 774 use_sam_stats=True, 775 final_activation="Sigmoid", 776 use_skip_connection=False, 777 resize_input=True, 778 ) 779 if decoder_state is not None: 780 unetr_state_dict = unetr.state_dict() 781 for k, v in unetr_state_dict.items(): 782 if not k.startswith("encoder"): 783 unetr_state_dict[k] = decoder_state[k] 784 unetr.load_state_dict(unetr_state_dict) 785 786 unetr.to(device) 787 return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device.
Returns:
The UNETR model.
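A hedged sketch of building the UNETR; the model type is an example, and only fine-tuned micro_sam checkpoints contain a decoder state, so the state lookup here is an assumption:
```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

predictor, state = util.get_sam_model(model_type="vit_b", return_state=True)

# Without a decoder_state the decoder weights stay randomly initialized,
# which is mainly useful as a starting point for training.
unetr = get_unetr(predictor.model.image_encoder, decoder_state=state.get("decoder_state"))
```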
790def get_decoder( 791 image_encoder: torch.nn.Module, 792 decoder_state: OrderedDict[str, torch.Tensor], 793 device: Optional[Union[str, torch.device]] = None, 794) -> DecoderAdapter: 795 """Get decoder to predict outputs for automatic instance segmentation 796 797 Args: 798 image_encoder: The image encoder of the SAM model. 799 decoder_state: State to initialize the weights of the UNETR decoder. 800 device: The device. 801 Returns: 802 The decoder for instance segmentation. 803 """ 804 unetr = get_unetr(image_encoder, decoder_state, device) 805 return DecoderAdapter(unetr)
Get decoder to predict outputs for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device.
Returns:
The decoder for instance segmentation.
808def get_predictor_and_decoder( 809 model_type: str, 810 checkpoint_path: Union[str, os.PathLike], 811 device: Optional[Union[str, torch.device]] = None, 812 peft_kwargs: Optional[Dict] = None, 813) -> Tuple[SamPredictor, DecoderAdapter]: 814 """Load the SAM model (predictor) and instance segmentation decoder. 815 816 This requires a checkpoint that contains the state for both predictor 817 and decoder. 818 819 Args: 820 model_type: The type of the image encoder used in the SAM model. 821 checkpoint_path: Path to the checkpoint from which to load the data. 822 device: The device. 823 lora_rank: The rank for low rank adaptation of the attention layers. 824 825 Returns: 826 The SAM predictor. 827 The decoder for instance segmentation. 828 """ 829 device = util.get_device(device) 830 predictor, state = util.get_sam_model( 831 model_type=model_type, 832 checkpoint_path=checkpoint_path, 833 device=device, 834 return_state=True, 835 peft_kwargs=peft_kwargs, 836 ) 837 if "decoder_state" not in state: 838 raise ValueError( 839 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 840 ) 841 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 842 return predictor, decoder
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device.
- peft_kwargs: Keyword arguments for parameter efficient fine-tuning, e.g. the rank for low rank adaptation of the attention layers.
Returns:
The SAM predictor. The decoder for instance segmentation.
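Typical usage, assuming a checkpoint that bundles the SAM weights with a decoder state (the checkpoint path is a placeholder):
```python
from micro_sam.instance_segmentation import get_predictor_and_decoder

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b",
    checkpoint_path="checkpoints/vit_b_with_decoder.pt",  # placeholder path
)
```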
845class InstanceSegmentationWithDecoder: 846 """Generates an instance segmentation without prompts, using a decoder. 847 848 Implements the same interface as `AutomaticMaskGenerator`. 849 850 Use this class as follows: 851 ```python 852 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 853 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 854 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 855 ``` 856 857 Args: 858 predictor: The segment anything predictor. 859 decoder: The decoder to predict intermediate representations 860 for instance segmentation. 861 """ 862 def __init__( 863 self, 864 predictor: SamPredictor, 865 decoder: torch.nn.Module, 866 ) -> None: 867 self._predictor = predictor 868 self._decoder = decoder 869 870 # The decoder outputs. 871 self._foreground = None 872 self._center_distances = None 873 self._boundary_distances = None 874 875 self._is_initialized = False 876 877 @property 878 def is_initialized(self): 879 """Whether the mask generator has already been initialized. 880 """ 881 return self._is_initialized 882 883 @torch.no_grad() 884 def initialize( 885 self, 886 image: np.ndarray, 887 image_embeddings: Optional[util.ImageEmbeddings] = None, 888 i: Optional[int] = None, 889 verbose: bool = False, 890 pbar_init: Optional[callable] = None, 891 pbar_update: Optional[callable] = None, 892 ) -> None: 893 """Initialize image embeddings and decoder predictions for an image. 894 895 Args: 896 image: The input image, volume or timeseries. 897 image_embeddings: Optional precomputed image embeddings. 898 See `util.precompute_image_embeddings` for details. 899 i: Index for the image data. Required if `image` has three spatial dimensions 900 or a time dimension and two spatial dimensions. 901 verbose: Whether to be verbose. 902 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 903 Can be used together with pbar_update to handle napari progress bar in other thread. 904 To enables using this function within a threadworker. 905 pbar_update: Callback to update an external progress bar. 906 """ 907 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 908 pbar_init(1, "Initialize instance segmentation with decoder") 909 910 if image_embeddings is None: 911 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 912 913 # Get the image embeddings from the predictor. 914 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 915 embeddings = self._predictor.features 916 input_shape = tuple(self._predictor.input_size) 917 original_shape = tuple(self._predictor.original_size) 918 919 # Run prediction with the UNETR decoder. 920 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 921 assert output.shape[0] == 3, f"{output.shape}" 922 pbar_update(1) 923 pbar_close() 924 925 # Set the state. 
926 self._foreground = output[0] 927 self._center_distances = output[1] 928 self._boundary_distances = output[2] 929 self._is_initialized = True 930 931 def _to_masks(self, segmentation, output_mode): 932 if output_mode != "binary_mask": 933 raise NotImplementedError 934 935 props = regionprops(segmentation) 936 ndim = segmentation.ndim 937 assert ndim in (2, 3) 938 939 shape = segmentation.shape 940 if ndim == 2: 941 crop_box = [0, shape[1], 0, shape[0]] 942 else: 943 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 944 945 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 946 def to_bbox_2d(bbox): 947 y0, x0 = bbox[0], bbox[1] 948 w = bbox[3] - x0 949 h = bbox[2] - y0 950 return [x0, w, y0, h] 951 952 def to_bbox_3d(bbox): 953 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 954 w = bbox[5] - x0 955 h = bbox[4] - y0 956 d = bbox[3] - y0 957 return [x0, w, y0, h, z0, d] 958 959 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 960 masks = [ 961 { 962 "segmentation": segmentation == prop.label, 963 "area": prop.area, 964 "bbox": to_bbox(prop.bbox), 965 "crop_box": crop_box, 966 "seg_id": prop.label, 967 } for prop in props 968 ] 969 return masks 970 971 def generate( 972 self, 973 center_distance_threshold: float = 0.5, 974 boundary_distance_threshold: float = 0.5, 975 foreground_threshold: float = 0.5, 976 foreground_smoothing: float = 1.0, 977 distance_smoothing: float = 1.6, 978 min_size: int = 0, 979 output_mode: Optional[str] = "binary_mask", 980 ) -> List[Dict[str, Any]]: 981 """Generate instance segmentation for the currently initialized image. 982 983 Args: 984 center_distance_threshold: Center distance predictions below this value will be 985 used to find seeds (intersected with thresholded boundary distance predictions). 986 boundary_distance_threshold: Boundary distance predictions below this value will be 987 used to find seeds (intersected with thresholded center distance predictions). 988 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 989 checkerboard artifacts in the prediction. 990 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 991 distance_smoothing: Sigma value for smoothing the distance predictions. 992 min_size: Minimal object size in the segmentation result. 993 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 994 995 Returns: 996 The instance segmentation masks. 997 """ 998 if not self.is_initialized: 999 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 1000 1001 if foreground_smoothing > 0: 1002 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1003 else: 1004 foreground = self._foreground 1005 # Further optimization: parallel implementation using elf.parallel functionality. 
1006 # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) 1007 segmentation = watershed_from_center_and_boundary_distances( 1008 self._center_distances, self._boundary_distances, foreground, 1009 center_distance_threshold=center_distance_threshold, 1010 boundary_distance_threshold=boundary_distance_threshold, 1011 foreground_threshold=foreground_threshold, 1012 distance_smoothing=distance_smoothing, 1013 min_size=min_size, 1014 ) 1015 if output_mode is not None: 1016 segmentation = self._to_masks(segmentation, output_mode) 1017 return segmentation 1018 1019 def get_state(self) -> Dict[str, Any]: 1020 """Get the initialized state of the instance segmenter. 1021 1022 Returns: 1023 Instance segmentation state. 1024 """ 1025 if not self.is_initialized: 1026 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1027 1028 return { 1029 "foreground": self._foreground, 1030 "center_distances": self._center_distances, 1031 "boundary_distances": self._boundary_distances, 1032 } 1033 1034 def set_state(self, state: Dict[str, Any]) -> None: 1035 """Set the state of the instance segmenter. 1036 1037 Args: 1038 state: The instance segmentation state 1039 """ 1040 self._foreground = state["foreground"] 1041 self._center_distances = state["center_distances"] 1042 self._boundary_distances = state["boundary_distances"] 1043 self._is_initialized = True 1044 1045 def clear_state(self): 1046 """Clear the state of the instance segmenter. 1047 """ 1048 self._foreground = None 1049 self._center_distances = None 1050 self._boundary_distances = None 1051 self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as `AutomaticMaskGenerator`.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
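A minimal sketch of the decoder-based workflow; `predictor` and `decoder` could come e.g. from `get_predictor_and_decoder` above, and the image is a placeholder:
```python
import numpy as np
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder

image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # predicts embeddings, foreground and distance maps

# With output_mode=None the instance segmentation (label image) is returned directly.
segmentation = segmenter.generate(center_distance_threshold=0.5, output_mode=None)
```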
862 def __init__( 863 self, 864 predictor: SamPredictor, 865 decoder: torch.nn.Module, 866 ) -> None: 867 self._predictor = predictor 868 self._decoder = decoder 869 870 # The decoder outputs. 871 self._foreground = None 872 self._center_distances = None 873 self._boundary_distances = None 874 875 self._is_initialized = False
877 @property 878 def is_initialized(self): 879 """Whether the mask generator has already been initialized. 880 """ 881 return self._is_initialized
Whether the mask generator has already been initialized.
883 @torch.no_grad() 884 def initialize( 885 self, 886 image: np.ndarray, 887 image_embeddings: Optional[util.ImageEmbeddings] = None, 888 i: Optional[int] = None, 889 verbose: bool = False, 890 pbar_init: Optional[callable] = None, 891 pbar_update: Optional[callable] = None, 892 ) -> None: 893 """Initialize image embeddings and decoder predictions for an image. 894 895 Args: 896 image: The input image, volume or timeseries. 897 image_embeddings: Optional precomputed image embeddings. 898 See `util.precompute_image_embeddings` for details. 899 i: Index for the image data. Required if `image` has three spatial dimensions 900 or a time dimension and two spatial dimensions. 901 verbose: Whether to be verbose. 902 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 903 Can be used together with pbar_update to handle napari progress bar in other thread. 904 To enables using this function within a threadworker. 905 pbar_update: Callback to update an external progress bar. 906 """ 907 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 908 pbar_init(1, "Initialize instance segmentation with decoder") 909 910 if image_embeddings is None: 911 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 912 913 # Get the image embeddings from the predictor. 914 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 915 embeddings = self._predictor.features 916 input_shape = tuple(self._predictor.input_size) 917 original_shape = tuple(self._predictor.original_size) 918 919 # Run prediction with the UNETR decoder. 920 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 921 assert output.shape[0] == 3, f"{output.shape}" 922 pbar_update(1) 923 pbar_close() 924 925 # Set the state. 926 self._foreground = output[0] 927 self._center_distances = output[1] 928 self._boundary_distances = output[2] 929 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
971 def generate( 972 self, 973 center_distance_threshold: float = 0.5, 974 boundary_distance_threshold: float = 0.5, 975 foreground_threshold: float = 0.5, 976 foreground_smoothing: float = 1.0, 977 distance_smoothing: float = 1.6, 978 min_size: int = 0, 979 output_mode: Optional[str] = "binary_mask", 980 ) -> List[Dict[str, Any]]: 981 """Generate instance segmentation for the currently initialized image. 982 983 Args: 984 center_distance_threshold: Center distance predictions below this value will be 985 used to find seeds (intersected with thresholded boundary distance predictions). 986 boundary_distance_threshold: Boundary distance predictions below this value will be 987 used to find seeds (intersected with thresholded center distance predictions). 988 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 989 checkerboard artifacts in the prediction. 990 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 991 distance_smoothing: Sigma value for smoothing the distance predictions. 992 min_size: Minimal object size in the segmentation result. 993 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 994 995 Returns: 996 The instance segmentation masks. 997 """ 998 if not self.is_initialized: 999 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 1000 1001 if foreground_smoothing > 0: 1002 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1003 else: 1004 foreground = self._foreground 1005 # Further optimization: parallel implementation using elf.parallel functionality. 1006 # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) 1007 segmentation = watershed_from_center_and_boundary_distances( 1008 self._center_distances, self._boundary_distances, foreground, 1009 center_distance_threshold=center_distance_threshold, 1010 boundary_distance_threshold=boundary_distance_threshold, 1011 foreground_threshold=foreground_threshold, 1012 distance_smoothing=distance_smoothing, 1013 min_size=min_size, 1014 ) 1015 if output_mode is not None: 1016 segmentation = self._to_masks(segmentation, output_mode) 1017 return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask.
- distance_smoothing: Sigma value for smoothing the distance predictions.
- min_size: Minimal object size in the segmentation result.
- output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
Returns:
The instance segmentation masks.
1019 def get_state(self) -> Dict[str, Any]: 1020 """Get the initialized state of the instance segmenter. 1021 1022 Returns: 1023 Instance segmentation state. 1024 """ 1025 if not self.is_initialized: 1026 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1027 1028 return { 1029 "foreground": self._foreground, 1030 "center_distances": self._center_distances, 1031 "boundary_distances": self._boundary_distances, 1032 }
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
1034 def set_state(self, state: Dict[str, Any]) -> None: 1035 """Set the state of the instance segmenter. 1036 1037 Args: 1038 state: The instance segmentation state 1039 """ 1040 self._foreground = state["foreground"] 1041 self._center_distances = state["center_distances"] 1042 self._boundary_distances = state["boundary_distances"] 1043 self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state.
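The state accessors make it possible to cache and restore the decoder predictions, for example to avoid recomputation across sessions. A sketch, with `segmenter`, `predictor` and `decoder` as above (how the state is persisted, e.g. via pickle or numpy files, is left out):
```python
# Cache the decoder predictions of an initialized segmenter ...
state = segmenter.get_state()

# ... and restore them into a fresh segmenter, which is then usable without initialize().
restored = InstanceSegmentationWithDecoder(predictor, decoder)
restored.set_state(state)
masks = restored.generate(min_size=100)

# clear_state() resets the segmenter so it has to be initialized again.
restored.clear_state()
```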
1054class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1055 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1056 """ 1057 1058 @torch.no_grad() 1059 def initialize( 1060 self, 1061 image: np.ndarray, 1062 image_embeddings: Optional[util.ImageEmbeddings] = None, 1063 i: Optional[int] = None, 1064 tile_shape: Optional[Tuple[int, int]] = None, 1065 halo: Optional[Tuple[int, int]] = None, 1066 verbose: bool = False, 1067 pbar_init: Optional[callable] = None, 1068 pbar_update: Optional[callable] = None, 1069 ) -> None: 1070 """Initialize image embeddings and decoder predictions for an image. 1071 1072 Args: 1073 image: The input image, volume or timeseries. 1074 image_embeddings: Optional precomputed image embeddings. 1075 See `util.precompute_image_embeddings` for details. 1076 i: Index for the image data. Required if `image` has three spatial dimensions 1077 or a time dimension and two spatial dimensions. 1078 verbose: Dummy input to be compatible with other function signatures. 1079 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1080 Can be used together with pbar_update to handle napari progress bar in other thread. 1081 To enables using this function within a threadworker. 1082 pbar_update: Callback to update an external progress bar. 1083 """ 1084 original_size = image.shape[:2] 1085 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1086 self._predictor, image, image_embeddings, tile_shape, halo 1087 ) 1088 tiling = blocking([0, 0], original_size, tile_shape) 1089 1090 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1091 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1092 1093 foreground = np.zeros(original_size, dtype="float32") 1094 center_distances = np.zeros(original_size, dtype="float32") 1095 boundary_distances = np.zeros(original_size, dtype="float32") 1096 1097 for tile_id in range(tiling.numberOfBlocks): 1098 1099 # Get the image embeddings from the predictor for this tile. 1100 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1101 embeddings = self._predictor.features 1102 input_shape = tuple(self._predictor.input_size) 1103 original_shape = tuple(self._predictor.original_size) 1104 1105 # Predict with the UNETR decoder for this tile. 1106 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1107 assert output.shape[0] == 3, f"{output.shape}" 1108 1109 # Set the predictions in the output for this tile. 1110 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1111 local_bb = tuple( 1112 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1113 ) 1114 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1115 1116 foreground[inner_bb] = output[0][local_bb] 1117 center_distances[inner_bb] = output[1][local_bb] 1118 boundary_distances[inner_bb] = output[2][local_bb] 1119 pbar_close() 1120 1121 # Set the state. 1122 self._foreground = foreground 1123 self._center_distances = center_distances 1124 self._boundary_distances = boundary_distances 1125 self._is_initialized = True
Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1058 @torch.no_grad() 1059 def initialize( 1060 self, 1061 image: np.ndarray, 1062 image_embeddings: Optional[util.ImageEmbeddings] = None, 1063 i: Optional[int] = None, 1064 tile_shape: Optional[Tuple[int, int]] = None, 1065 halo: Optional[Tuple[int, int]] = None, 1066 verbose: bool = False, 1067 pbar_init: Optional[callable] = None, 1068 pbar_update: Optional[callable] = None, 1069 ) -> None: 1070 """Initialize image embeddings and decoder predictions for an image. 1071 1072 Args: 1073 image: The input image, volume or timeseries. 1074 image_embeddings: Optional precomputed image embeddings. 1075 See `util.precompute_image_embeddings` for details. 1076 i: Index for the image data. Required if `image` has three spatial dimensions 1077 or a time dimension and two spatial dimensions. 1078 verbose: Dummy input to be compatible with other function signatures. 1079 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1080 Can be used together with pbar_update to handle napari progress bar in other thread. 1081 To enables using this function within a threadworker. 1082 pbar_update: Callback to update an external progress bar. 1083 """ 1084 original_size = image.shape[:2] 1085 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1086 self._predictor, image, image_embeddings, tile_shape, halo 1087 ) 1088 tiling = blocking([0, 0], original_size, tile_shape) 1089 1090 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1091 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1092 1093 foreground = np.zeros(original_size, dtype="float32") 1094 center_distances = np.zeros(original_size, dtype="float32") 1095 boundary_distances = np.zeros(original_size, dtype="float32") 1096 1097 for tile_id in range(tiling.numberOfBlocks): 1098 1099 # Get the image embeddings from the predictor for this tile. 1100 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1101 embeddings = self._predictor.features 1102 input_shape = tuple(self._predictor.input_size) 1103 original_shape = tuple(self._predictor.original_size) 1104 1105 # Predict with the UNETR decoder for this tile. 1106 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1107 assert output.shape[0] == 3, f"{output.shape}" 1108 1109 # Set the predictions in the output for this tile. 1110 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1111 local_bb = tuple( 1112 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1113 ) 1114 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1115 1116 foreground[inner_bb] = output[0][local_bb] 1117 center_distances[inner_bb] = output[1][local_bb] 1118 boundary_distances[inner_bb] = output[2][local_bb] 1119 pbar_close() 1120 1121 # Set the state. 1122 self._foreground = foreground 1123 self._center_distances = center_distances 1124 self._boundary_distances = boundary_distances 1125 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Dummy input to be compatible with other function signatures.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
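A sketch of the tiled decoder-based workflow; `predictor`, `decoder` and `large_image` are as in the earlier sketches and the tile_shape and halo values are illustrative:
```python
from micro_sam.instance_segmentation import TiledInstanceSegmentationWithDecoder

tiled_segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
tiled_segmenter.initialize(large_image, tile_shape=(512, 512), halo=(64, 64))
segmentation = tiled_segmenter.generate(output_mode=None)
```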
1128def get_amg( 1129 predictor: SamPredictor, 1130 is_tiled: bool, 1131 decoder: Optional[torch.nn.Module] = None, 1132 **kwargs, 1133) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1134 """Get the automatic mask generator class. 1135 1136 Args: 1137 predictor: The segment anything predictor. 1138 is_tiled: Whether tiled embeddings are used. 1139 decoder: Decoder to predict instacne segmmentation. 1140 kwargs: The keyword arguments for the amg class. 1141 1142 Returns: 1143 The automatic mask generator. 1144 """ 1145 if decoder is None: 1146 segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\ 1147 AutomaticMaskGenerator(predictor, **kwargs) 1148 else: 1149 segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\ 1150 InstanceSegmentationWithDecoder(predictor, decoder, **kwargs) 1151 return segmenter
Get the automatic mask generator class.
Arguments:
- predictor: The segment anything predictor.
- is_tiled: Whether tiled embeddings are used.
- decoder: Decoder to predict the instance segmentation.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator.
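A sketch of how `get_amg` dispatches to the different classes; keyword arguments are forwarded to the respective constructor (`predictor` and `decoder` as above):
```python
from micro_sam.instance_segmentation import get_amg

# Without a decoder: a (tiled) AutomaticMaskGenerator; kwargs go to its constructor.
amg = get_amg(predictor, is_tiled=False, points_per_side=64)

# With a decoder: a (tiled) InstanceSegmentationWithDecoder.
segmenter = get_amg(predictor, is_tiled=True, decoder=decoder)
```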