micro_sam.instance_segmentation
Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
1"""Automated instance segmentation functionality. 2The classes implemented here extend the automatic instance segmentation from Segment Anything: 3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html 4""" 5 6import os 7import warnings 8from abc import ABC 9from copy import deepcopy 10from collections import OrderedDict 11from typing import Any, Dict, List, Optional, Tuple, Union 12 13import vigra 14import numpy as np 15from skimage.measure import label, regionprops 16from skimage.segmentation import relabel_sequential 17 18import torch 19from torchvision.ops.boxes import batched_nms, box_area 20 21from torch_em.model import UNETR 22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances 23 24import elf.parallel as parallel 25from elf.parallel.filters import apply_filter 26 27from nifty.tools import blocking 28 29import segment_anything.utils.amg as amg_utils 30from segment_anything.predictor import SamPredictor 31 32from . import util 33from ._vendored import batched_mask_to_box, mask_to_rle_pytorch 34 35# 36# Utility Functionality 37# 38 39 40class _FakeInput: 41 def __init__(self, shape): 42 self.shape = shape 43 44 def __getitem__(self, index): 45 block_shape = tuple(ind.stop - ind.start for ind in index) 46 return np.zeros(block_shape, dtype="float32") 47 48 49def mask_data_to_segmentation( 50 masks: List[Dict[str, Any]], 51 with_background: bool, 52 min_object_size: int = 0, 53 max_object_size: Optional[int] = None, 54 label_masks: bool = True, 55) -> np.ndarray: 56 """Convert the output of the automatic mask generation to an instance segmentation. 57 58 Args: 59 masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. 60 Only supports output_mode=binary_mask. 61 with_background: Whether the segmentation has background. If yes this function assures that the largest 62 object in the output will be mapped to zero (the background value). 63 min_object_size: The minimal size of an object in pixels. By default, set to '0'. 64 max_object_size: The maximal size of an object in pixels. 65 label_masks: Whether to apply connected components to the result before removing small objects. 66 By default, set to 'True'. 67 68 Returns: 69 The instance segmentation. 70 """ 71 72 masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 73 # we could also get the shape from the crop box 74 shape = next(iter(masks))["segmentation"].shape 75 segmentation = np.zeros(shape, dtype="uint32") 76 77 def require_numpy(mask): 78 return mask.cpu().numpy() if torch.is_tensor(mask) else mask 79 80 seg_id = 1 81 for mask in masks: 82 if mask["area"] < min_object_size: 83 continue 84 if max_object_size is not None and mask["area"] > max_object_size: 85 continue 86 87 this_seg_id = mask.get("seg_id", seg_id) 88 segmentation[require_numpy(mask["segmentation"])] = this_seg_id 89 seg_id = this_seg_id + 1 90 91 if label_masks: 92 segmentation = label(segmentation).astype(segmentation.dtype) 93 94 seg_ids, sizes = np.unique(segmentation, return_counts=True) 95 96 # In some cases objects may be smaller than peviously calculated, 97 # since they are covered by other objects. We ensure these also get 98 # filtered out here. 99 filter_ids = seg_ids[sizes < min_object_size] 100 101 # If we run segmentation with background we also map the largest segment 102 # (the most likely background object) to zero. This is often zero already, 103 # but it does not hurt to reset that to zero either. 
104 if with_background: 105 bg_id = seg_ids[np.argmax(sizes)] 106 filter_ids = np.concatenate([filter_ids, [bg_id]]) 107 108 segmentation[np.isin(segmentation, filter_ids)] = 0 109 segmentation = relabel_sequential(segmentation)[0] 110 111 return segmentation 112 113 114# 115# Classes for automatic instance segmentation 116# 117 118 119class AMGBase(ABC): 120 """Base class for the automatic mask generators. 121 """ 122 def __init__(self): 123 # the state that has to be computed by the 'initialize' method of the child classes 124 self._is_initialized = False 125 self._crop_list = None 126 self._crop_boxes = None 127 self._original_size = None 128 129 @property 130 def is_initialized(self): 131 """Whether the mask generator has already been initialized. 132 """ 133 return self._is_initialized 134 135 @property 136 def crop_list(self): 137 """The list of mask data after initialization. 138 """ 139 return self._crop_list 140 141 @property 142 def crop_boxes(self): 143 """The list of crop boxes. 144 """ 145 return self._crop_boxes 146 147 @property 148 def original_size(self): 149 """The original image size. 150 """ 151 return self._original_size 152 153 def _postprocess_batch( 154 self, 155 data, 156 crop_box, 157 original_size, 158 pred_iou_thresh, 159 stability_score_thresh, 160 box_nms_thresh, 161 ): 162 orig_h, orig_w = original_size 163 164 # filter by predicted IoU 165 if pred_iou_thresh > 0.0: 166 keep_mask = data["iou_preds"] > pred_iou_thresh 167 data.filter(keep_mask) 168 169 # filter by stability score 170 if stability_score_thresh > 0.0: 171 keep_mask = data["stability_score"] >= stability_score_thresh 172 data.filter(keep_mask) 173 174 # filter boxes that touch crop boundaries 175 keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 176 if not torch.all(keep_mask): 177 data.filter(keep_mask) 178 179 # remove duplicates within this crop. 
180 keep_by_nms = batched_nms( 181 data["boxes"].float(), 182 data["iou_preds"], 183 torch.zeros_like(data["boxes"][:, 0]), # categories 184 iou_threshold=box_nms_thresh, 185 ) 186 data.filter(keep_by_nms) 187 188 # return to the original image frame 189 data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box) 190 data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 191 # the data from embedding based segmentation doesn't have the points 192 # so we skip if the corresponding key can't be found 193 try: 194 data["points"] = amg_utils.uncrop_points(data["points"], crop_box) 195 except KeyError: 196 pass 197 198 return data 199 200 def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): 201 202 if len(mask_data["rles"]) == 0: 203 return mask_data 204 205 # filter small disconnected regions and holes 206 new_masks = [] 207 scores = [] 208 for rle in mask_data["rles"]: 209 mask = amg_utils.rle_to_mask(rle) 210 211 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes") 212 unchanged = not changed 213 mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands") 214 unchanged = unchanged and not changed 215 216 new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0)) 217 # give score=0 to changed masks and score=1 to unchanged masks 218 # so NMS will prefer ones that didn't need postprocessing 219 scores.append(float(unchanged)) 220 221 # recalculate boxes and remove any new duplicates 222 masks = torch.cat(new_masks, dim=0) 223 boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. 224 keep_by_nms = batched_nms( 225 boxes.float(), 226 torch.as_tensor(scores, dtype=torch.float), 227 torch.zeros_like(boxes[:, 0]), # categories 228 iou_threshold=nms_thresh, 229 ) 230 231 # only recalculate RLEs for masks that have changed 232 for i_mask in keep_by_nms: 233 if scores[i_mask] == 0.0: 234 mask_torch = masks[i_mask].unsqueeze(0) 235 # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] 236 mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 237 mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 238 mask_data.filter(keep_by_nms) 239 240 return mask_data 241 242 def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode): 243 # filter small disconnected regions and holes in masks 244 if min_mask_region_area > 0: 245 mask_data = self._postprocess_small_regions( 246 mask_data, 247 min_mask_region_area, 248 max(box_nms_thresh, crop_nms_thresh), 249 ) 250 251 # encode masks 252 if output_mode == "coco_rle": 253 mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]] 254 elif output_mode == "binary_mask": 255 mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]] 256 else: 257 mask_data["segmentations"] = mask_data["rles"] 258 259 # write mask records 260 curr_anns = [] 261 for idx in range(len(mask_data["segmentations"])): 262 ann = { 263 "segmentation": mask_data["segmentations"][idx], 264 "area": amg_utils.area_from_rle(mask_data["rles"][idx]), 265 "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 266 "predicted_iou": mask_data["iou_preds"][idx].item(), 267 "stability_score": mask_data["stability_score"][idx].item(), 268 "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 269 } 270 # the data from embedding based segmentation doesn't have the 
points 271 # so we skip if the corresponding key can't be found 272 try: 273 ann["point_coords"] = [mask_data["points"][idx].tolist()] 274 except KeyError: 275 pass 276 curr_anns.append(ann) 277 278 return curr_anns 279 280 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): 281 orig_h, orig_w = original_size 282 283 # serialize predictions and store in MaskData 284 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1)) 285 if points is not None: 286 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float) 287 288 del masks 289 290 # calculate the stability scores 291 data["stability_score"] = amg_utils.calculate_stability_score( 292 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset 293 ) 294 295 # threshold masks and calculate boxes 296 data["masks"] = data["masks"] > self._predictor.model.mask_threshold 297 data["masks"] = data["masks"].type(torch.bool) 298 data["boxes"] = batched_mask_to_box(data["masks"]) 299 300 # compress to RLE 301 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 302 # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) 303 data["rles"] = mask_to_rle_pytorch(data["masks"]) 304 del data["masks"] 305 306 return data 307 308 def get_state(self) -> Dict[str, Any]: 309 """Get the initialized state of the mask generator. 310 311 Returns: 312 State of the mask generator. 313 """ 314 if not self.is_initialized: 315 raise RuntimeError("The state has not been computed yet. Call initialize first.") 316 317 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size} 318 319 def set_state(self, state: Dict[str, Any]) -> None: 320 """Set the state of the mask generator. 321 322 Args: 323 state: The state of the mask generator, e.g. from serialized state. 324 """ 325 self._crop_list = state["crop_list"] 326 self._crop_boxes = state["crop_boxes"] 327 self._original_size = state["original_size"] 328 self._is_initialized = True 329 330 def clear_state(self): 331 """Clear the state of the mask generator. 332 """ 333 self._crop_list = None 334 self._crop_boxes = None 335 self._original_size = None 336 self._is_initialized = False 337 338 339class AutomaticMaskGenerator(AMGBase): 340 """Generates an instance segmentation without prompts, using a point grid. 341 342 This class implements the same logic as 343 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 344 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 345 to filter these masks to enable grid search and interactively changing the post-processing. 346 347 Use this class as follows: 348 ```python 349 amg = AutomaticMaskGenerator(predictor) 350 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 351 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 352 ``` 353 354 Args: 355 predictor: The segment anything predictor. 356 points_per_side: The number of points to be sampled along one side of the image. 357 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 358 points_per_batch: The number of points run simultaneously by the model. 359 Higher numbers may be faster but use more GPU memory. 
360 By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons). 361 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 362 By default, set to '0'. 363 crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'. 364 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 365 By default, set to '1'. 366 point_grids: A list over explicit grids of points used for sampling masks. 367 Normalized to [0, 1] with respect to the image coordinate system. 368 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 369 By default, set to '1.0'. 370 """ 371 def __init__( 372 self, 373 predictor: SamPredictor, 374 points_per_side: Optional[int] = 32, 375 points_per_batch: Optional[int] = None, 376 crop_n_layers: int = 0, 377 crop_overlap_ratio: float = 512 / 1500, 378 crop_n_points_downscale_factor: int = 1, 379 point_grids: Optional[List[np.ndarray]] = None, 380 stability_score_offset: float = 1.0, 381 ): 382 super().__init__() 383 384 if points_per_side is not None: 385 self.point_grids = amg_utils.build_all_layer_point_grids( 386 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 387 ) 388 elif point_grids is not None: 389 self.point_grids = point_grids 390 else: 391 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 392 393 self._predictor = predictor 394 self._points_per_side = points_per_side 395 396 # we set the points per batch to 16 for mps for performance reasons 397 # and otherwise keep them at the default of 64 398 if points_per_batch is None: 399 points_per_batch = 16 if str(predictor.device) == "mps" else 64 400 self._points_per_batch = points_per_batch 401 402 self._crop_n_layers = crop_n_layers 403 self._crop_overlap_ratio = crop_overlap_ratio 404 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 405 self._stability_score_offset = stability_score_offset 406 407 def _process_batch(self, points, im_size, crop_box, original_size): 408 # run model on this batch 409 transformed_points = self._predictor.transform.apply_coords(points, im_size) 410 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 411 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 412 masks, iou_preds, _ = self._predictor.predict_torch( 413 point_coords=in_points[:, None, :], 414 point_labels=in_labels[:, None], 415 multimask_output=True, 416 return_logits=True, 417 ) 418 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 419 del masks 420 return data 421 422 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 423 # Crop the image and calculate embeddings. 424 x0, y0, x1, y1 = crop_box 425 cropped_im = image[y0:y1, x0:x1, :] 426 cropped_im_size = cropped_im.shape[:2] 427 428 if not precomputed_embeddings: 429 self._predictor.set_image(cropped_im) 430 431 # Get the points for this crop. 432 points_scale = np.array(cropped_im_size)[None, ::-1] 433 points_for_image = self.point_grids[crop_layer_idx] * points_scale 434 435 # Generate masks for this crop in batches. 
436 data = amg_utils.MaskData() 437 n_batches = len(points_for_image) // self._points_per_batch +\ 438 int(len(points_for_image) % self._points_per_batch != 0) 439 if pbar_init is not None: 440 pbar_init(n_batches, "Predict masks for point grid prompts") 441 442 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 443 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 444 data.cat(batch_data) 445 del batch_data 446 if pbar_update is not None: 447 pbar_update(1) 448 449 if not precomputed_embeddings: 450 self._predictor.reset_image() 451 452 return data 453 454 @torch.no_grad() 455 def initialize( 456 self, 457 image: np.ndarray, 458 image_embeddings: Optional[util.ImageEmbeddings] = None, 459 i: Optional[int] = None, 460 verbose: bool = False, 461 pbar_init: Optional[callable] = None, 462 pbar_update: Optional[callable] = None, 463 ) -> None: 464 """Initialize image embeddings and masks for an image. 465 466 Args: 467 image: The input image, volume or timeseries. 468 image_embeddings: Optional precomputed image embeddings. 469 See `util.precompute_image_embeddings` for details. 470 i: Index for the image data. Required if `image` has three spatial dimensions 471 or a time dimension and two spatial dimensions. 472 verbose: Whether to print computation progress. By default, set to 'False'. 473 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 474 Can be used together with pbar_update to handle napari progress bar in other thread. 475 To enables using this function within a threadworker. 476 pbar_update: Callback to update an external progress bar. 477 """ 478 original_size = image.shape[:2] 479 self._original_size = original_size 480 481 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 482 original_size, self._crop_n_layers, self._crop_overlap_ratio 483 ) 484 485 # We can set fixed image embeddings if we only have a single crop box (the default setting). 486 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 487 if len(crop_boxes) == 1: 488 if image_embeddings is None: 489 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 490 util.set_precomputed(self._predictor, image_embeddings, i=i) 491 precomputed_embeddings = True 492 else: 493 precomputed_embeddings = False 494 495 # we need to cast to the image representation that is compatible with SAM 496 image = util._to_image(image) 497 498 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 499 500 crop_list = [] 501 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 502 crop_data = self._process_crop( 503 image, crop_box, layer_idx, 504 precomputed_embeddings=precomputed_embeddings, 505 pbar_init=pbar_init, pbar_update=pbar_update, 506 ) 507 crop_list.append(crop_data) 508 pbar_close() 509 510 self._is_initialized = True 511 self._crop_list = crop_list 512 self._crop_boxes = crop_boxes 513 514 @torch.no_grad() 515 def generate( 516 self, 517 pred_iou_thresh: float = 0.88, 518 stability_score_thresh: float = 0.95, 519 box_nms_thresh: float = 0.7, 520 crop_nms_thresh: float = 0.7, 521 min_mask_region_area: int = 0, 522 output_mode: str = "binary_mask", 523 ) -> List[Dict[str, Any]]: 524 """Generate instance segmentation for the currently initialized image. 525 526 Args: 527 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 528 By default, set to '0.88'. 
529 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 530 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 531 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 532 By default, set to '0.7'. 533 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 534 By default, set to '0.7'. 535 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 536 output_mode: The form masks are returned in. By default, set to 'binary_mask'. 537 538 Returns: 539 The instance segmentation masks. 540 """ 541 if not self.is_initialized: 542 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 543 544 data = amg_utils.MaskData() 545 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 546 crop_data = self._postprocess_batch( 547 data=deepcopy(data_), 548 crop_box=crop_box, original_size=self.original_size, 549 pred_iou_thresh=pred_iou_thresh, 550 stability_score_thresh=stability_score_thresh, 551 box_nms_thresh=box_nms_thresh 552 ) 553 data.cat(crop_data) 554 555 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 556 # Prefer masks from smaller crops 557 scores = 1 / box_area(data["crop_boxes"]) 558 scores = scores.to(data["boxes"].device) 559 keep_by_nms = batched_nms( 560 data["boxes"].float(), 561 scores, 562 torch.zeros_like(data["boxes"][:, 0]), # categories 563 iou_threshold=crop_nms_thresh, 564 ) 565 data.filter(keep_by_nms) 566 567 data.to_numpy() 568 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 569 return masks 570 571 572# Helper function for tiled embedding computation and checking consistent state. 573def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size): 574 if image_embeddings is None: 575 if tile_shape is None or halo is None: 576 raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.") 577 image_embeddings = util.precompute_image_embeddings( 578 predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size 579 ) 580 581 # Use tile shape and halo from the precomputed embeddings if not given. 582 # Otherwise check that they are consistent. 583 feats = image_embeddings["features"] 584 tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"]) 585 if tile_shape is None: 586 tile_shape = tile_shape_ 587 elif tile_shape != tile_shape_: 588 raise ValueError( 589 f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}." 590 ) 591 if halo is None: 592 halo = halo_ 593 elif halo != halo_: 594 raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.") 595 596 return image_embeddings, tile_shape, halo 597 598 599class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 600 """Generates an instance segmentation without prompts, using a point grid. 601 602 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 603 604 Args: 605 predictor: The Segment Anything predictor. 606 points_per_side: The number of points to be sampled along one side of the image. 607 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 608 points_per_batch: The number of points run simultaneously by the model. 
609 Higher numbers may be faster but use more GPU memory. By default, set to '64'. 610 point_grids: A list over explicit grids of points used for sampling masks. 611 Normalized to [0, 1] with respect to the image coordinate system. 612 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 613 By default, set to '1.0'. 614 """ 615 616 # We only expose the arguments that make sense for the tiled mask generator. 617 # Anything related to crops doesn't make sense, because we re-use that functionality 618 # for tiling, so these parameters wouldn't have any effect. 619 def __init__( 620 self, 621 predictor: SamPredictor, 622 points_per_side: Optional[int] = 32, 623 points_per_batch: int = 64, 624 point_grids: Optional[List[np.ndarray]] = None, 625 stability_score_offset: float = 1.0, 626 ) -> None: 627 super().__init__( 628 predictor=predictor, 629 points_per_side=points_per_side, 630 points_per_batch=points_per_batch, 631 point_grids=point_grids, 632 stability_score_offset=stability_score_offset, 633 ) 634 635 @torch.no_grad() 636 def initialize( 637 self, 638 image: np.ndarray, 639 image_embeddings: Optional[util.ImageEmbeddings] = None, 640 i: Optional[int] = None, 641 tile_shape: Optional[Tuple[int, int]] = None, 642 halo: Optional[Tuple[int, int]] = None, 643 verbose: bool = False, 644 pbar_init: Optional[callable] = None, 645 pbar_update: Optional[callable] = None, 646 batch_size: int = 1, 647 ) -> None: 648 """Initialize image embeddings and masks for an image. 649 650 Args: 651 image: The input image, volume or timeseries. 652 image_embeddings: Optional precomputed image embeddings. 653 See `util.precompute_image_embeddings` for details. 654 i: Index for the image data. Required if `image` has three spatial dimensions 655 or a time dimension and two spatial dimensions. 656 tile_shape: The tile shape for embedding prediction. 657 halo: The overlap of between tiles. 658 verbose: Whether to print computation progress. By default, set to 'False'. 659 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 660 Can be used together with pbar_update to handle napari progress bar in other thread. 661 To enables using this function within a threadworker. 662 pbar_update: Callback to update an external progress bar. 663 batch_size: The batch size for image embedding prediction. By default, set to '1'. 664 """ 665 original_size = image.shape[:2] 666 self._original_size = original_size 667 668 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 669 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 670 ) 671 672 tiling = blocking([0, 0], original_size, tile_shape) 673 n_tiles = tiling.numberOfBlocks 674 675 # The crop box is always the full local tile. 676 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 677 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 678 679 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 680 pbar_init(n_tiles, "Compute masks for tile") 681 682 # We need to cast to the image representation that is compatible with SAM. 
683 image = util._to_image(image) 684 685 mask_data = [] 686 for tile_id in range(n_tiles): 687 # set the pre-computed embeddings for this tile 688 features = image_embeddings["features"][str(tile_id)] 689 tile_embeddings = { 690 "features": features, 691 "input_size": features.attrs["input_size"], 692 "original_size": features.attrs["original_size"], 693 } 694 util.set_precomputed(self._predictor, tile_embeddings, i) 695 696 # compute the mask data for this tile and append it 697 this_mask_data = self._process_crop( 698 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 699 ) 700 mask_data.append(this_mask_data) 701 pbar_update(1) 702 pbar_close() 703 704 # set the initialized data 705 self._is_initialized = True 706 self._crop_list = mask_data 707 self._crop_boxes = crop_boxes 708 709 710# 711# Instance segmentation functionality based on fine-tuned decoder 712# 713 714 715class DecoderAdapter(torch.nn.Module): 716 """Adapter to contain the UNETR decoder in a single module. 717 718 To apply the decoder on top of pre-computed embeddings for the segmentation functionality. 719 See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py 720 """ 721 def __init__(self, unetr: torch.nn.Module): 722 super().__init__() 723 724 self.base = unetr.base 725 self.out_conv = unetr.out_conv 726 self.deconv_out = unetr.deconv_out 727 self.decoder_head = unetr.decoder_head 728 self.final_activation = unetr.final_activation 729 self.postprocess_masks = unetr.postprocess_masks 730 731 self.decoder = unetr.decoder 732 self.deconv1 = unetr.deconv1 733 self.deconv2 = unetr.deconv2 734 self.deconv3 = unetr.deconv3 735 self.deconv4 = unetr.deconv4 736 737 def _forward_impl(self, input_): 738 z12 = input_ 739 740 z9 = self.deconv1(z12) 741 z6 = self.deconv2(z9) 742 z3 = self.deconv3(z6) 743 z0 = self.deconv4(z3) 744 745 updated_from_encoder = [z9, z6, z3] 746 747 x = self.base(z12) 748 x = self.decoder(x, encoder_inputs=updated_from_encoder) 749 x = self.deconv_out(x) 750 751 x = torch.cat([x, z0], dim=1) 752 x = self.decoder_head(x) 753 754 x = self.out_conv(x) 755 if self.final_activation is not None: 756 x = self.final_activation(x) 757 return x 758 759 def forward(self, input_, input_shape, original_shape): 760 x = self._forward_impl(input_) 761 x = self.postprocess_masks(x, input_shape, original_shape) 762 return x 763 764 765def get_unetr( 766 image_encoder: torch.nn.Module, 767 decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None, 768 device: Optional[Union[str, torch.device]] = None, 769 out_channels: int = 3, 770 flexible_load_checkpoint: bool = False, 771) -> torch.nn.Module: 772 """Get UNETR model for automatic instance segmentation. 773 774 Args: 775 image_encoder: The image encoder of the SAM model. 776 This is used as encoder by the UNETR too. 777 decoder_state: Optional decoder state to initialize the weights of the UNETR decoder. 778 device: The device. By default, automatically chooses the best available device. 779 out_channels: The number of output channels. By default, set to '3'. 780 flexible_load_checkpoint: Whether to allow reinitialization of parameters 781 which could not be found in the provided decoder state. By default, set to 'False'. 782 783 Returns: 784 The UNETR model. 785 """ 786 device = util.get_device(device) 787 788 if decoder_state is None: 789 use_conv_transpose = False # By default, we use interpolation for upsampling. 
790 else: 791 # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions. 792 # NOTE: Explanation to the logic below - 793 # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers" 794 # submodules. This naming convention indicates that transposed convolutions are used, 795 # wrapped inside a custom block module. 796 # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling. 797 use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers")) 798 799 unetr = UNETR( 800 backbone="sam", 801 encoder=image_encoder, 802 out_channels=out_channels, 803 use_sam_stats=True, 804 final_activation="Sigmoid", 805 use_skip_connection=False, 806 resize_input=True, 807 use_conv_transpose=use_conv_transpose, 808 809 ) 810 811 if decoder_state is not None: 812 unetr_state_dict = unetr.state_dict() 813 for k, v in unetr_state_dict.items(): 814 if not k.startswith("encoder"): 815 if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found. 816 if k in decoder_state: # First check whether the key is available in the provided decoder state. 817 unetr_state_dict[k] = decoder_state[k] 818 else: # Otherwise, allow it to initialize it. 819 warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.") 820 unetr_state_dict[k] = v 821 822 else: # Whether be strict on finding the parameter in the decoder state. 823 if k not in decoder_state: 824 raise RuntimeError(f"The parameters for '{k}' could not be found.") 825 unetr_state_dict[k] = decoder_state[k] 826 827 unetr.load_state_dict(unetr_state_dict) 828 829 unetr.to(device) 830 return unetr 831 832 833def get_decoder( 834 image_encoder: torch.nn.Module, 835 decoder_state: OrderedDict[str, torch.Tensor], 836 device: Optional[Union[str, torch.device]] = None, 837) -> DecoderAdapter: 838 """Get decoder to predict outputs for automatic instance segmentation 839 840 Args: 841 image_encoder: The image encoder of the SAM model. 842 decoder_state: State to initialize the weights of the UNETR decoder. 843 device: The device. By default, automatically chooses the best available device. 844 845 Returns: 846 The decoder for instance segmentation. 847 """ 848 unetr = get_unetr(image_encoder, decoder_state, device) 849 return DecoderAdapter(unetr) 850 851 852def get_predictor_and_decoder( 853 model_type: str, 854 checkpoint_path: Union[str, os.PathLike], 855 device: Optional[Union[str, torch.device]] = None, 856 peft_kwargs: Optional[Dict] = None, 857) -> Tuple[SamPredictor, DecoderAdapter]: 858 """Load the SAM model (predictor) and instance segmentation decoder. 859 860 This requires a checkpoint that contains the state for both predictor 861 and decoder. 862 863 Args: 864 model_type: The type of the image encoder used in the SAM model. 865 checkpoint_path: Path to the checkpoint from which to load the data. 866 device: The device. By default, automatically chooses the best available device. 867 peft_kwargs: Keyword arguments for the PEFT wrapper class. 868 869 Returns: 870 The SAM predictor. 871 The decoder for instance segmentation. 
872 """ 873 device = util.get_device(device) 874 predictor, state = util.get_sam_model( 875 model_type=model_type, 876 checkpoint_path=checkpoint_path, 877 device=device, 878 return_state=True, 879 peft_kwargs=peft_kwargs, 880 ) 881 882 if "decoder_state" not in state: 883 raise ValueError( 884 f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" 885 ) 886 887 decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) 888 return predictor, decoder 889 890 891def _watershed_from_center_and_boundary_distances_parallel( 892 center_distances, 893 boundary_distances, 894 foreground_map, 895 center_distance_threshold, 896 boundary_distance_threshold, 897 foreground_threshold, 898 distance_smoothing, 899 min_size, 900 tile_shape, 901 halo, 902 n_threads, 903 verbose=False, 904): 905 center_distances = apply_filter( 906 center_distances, "gaussianSmoothing", sigma=distance_smoothing, 907 block_shape=tile_shape, n_threads=n_threads 908 ) 909 boundary_distances = apply_filter( 910 boundary_distances, "gaussianSmoothing", sigma=distance_smoothing, 911 block_shape=tile_shape, n_threads=n_threads 912 ) 913 914 fg_mask = foreground_map > foreground_threshold 915 916 marker_map = np.logical_and( 917 center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold 918 ) 919 marker_map[~fg_mask] = 0 920 921 markers = np.zeros(marker_map.shape, dtype="uint64") 922 markers = parallel.label( 923 marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose, 924 ) 925 926 seg = np.zeros_like(markers, dtype="uint64") 927 seg = parallel.seeded_watershed( 928 boundary_distances, seeds=markers, out=seg, block_shape=tile_shape, 929 halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask, 930 ) 931 932 out = np.zeros_like(seg, dtype="uint64") 933 out = parallel.size_filter( 934 seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose 935 ) 936 937 return out 938 939 940class InstanceSegmentationWithDecoder: 941 """Generates an instance segmentation without prompts, using a decoder. 942 943 Implements the same interface as `AutomaticMaskGenerator`. 944 945 Use this class as follows: 946 ```python 947 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 948 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 949 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 950 ``` 951 952 Args: 953 predictor: The segment anything predictor. 954 decoder: The decoder to predict intermediate representations 955 for instance segmentation. 956 """ 957 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 958 self._predictor = predictor 959 self._decoder = decoder 960 961 # The decoder outputs. 962 self._foreground = None 963 self._center_distances = None 964 self._boundary_distances = None 965 966 self._is_initialized = False 967 968 @property 969 def is_initialized(self): 970 """Whether the mask generator has already been initialized. 971 """ 972 return self._is_initialized 973 974 @torch.no_grad() 975 def initialize( 976 self, 977 image: np.ndarray, 978 image_embeddings: Optional[util.ImageEmbeddings] = None, 979 i: Optional[int] = None, 980 verbose: bool = False, 981 pbar_init: Optional[callable] = None, 982 pbar_update: Optional[callable] = None, 983 ) -> None: 984 """Initialize image embeddings and decoder predictions for an image. 
985 986 Args: 987 image: The input image, volume or timeseries. 988 image_embeddings: Optional precomputed image embeddings. 989 See `util.precompute_image_embeddings` for details. 990 i: Index for the image data. Required if `image` has three spatial dimensions 991 or a time dimension and two spatial dimensions. 992 verbose: Whether to be verbose. By default, set to 'False'. 993 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 994 Can be used together with pbar_update to handle napari progress bar in other thread. 995 To enables using this function within a threadworker. 996 pbar_update: Callback to update an external progress bar. 997 """ 998 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 999 pbar_init(1, "Initialize instance segmentation with decoder") 1000 1001 if image_embeddings is None: 1002 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 1003 1004 # Get the image embeddings from the predictor. 1005 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 1006 embeddings = self._predictor.features 1007 input_shape = tuple(self._predictor.input_size) 1008 original_shape = tuple(self._predictor.original_size) 1009 1010 # Run prediction with the UNETR decoder. 1011 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1012 assert output.shape[0] == 3, f"{output.shape}" 1013 pbar_update(1) 1014 pbar_close() 1015 1016 # Set the state. 1017 self._foreground = output[0] 1018 self._center_distances = output[1] 1019 self._boundary_distances = output[2] 1020 self._is_initialized = True 1021 1022 def _to_masks(self, segmentation, output_mode): 1023 if output_mode != "binary_mask": 1024 raise NotImplementedError 1025 1026 props = regionprops(segmentation) 1027 ndim = segmentation.ndim 1028 assert ndim in (2, 3) 1029 1030 shape = segmentation.shape 1031 if ndim == 2: 1032 crop_box = [0, shape[1], 0, shape[0]] 1033 else: 1034 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1035 1036 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1037 def to_bbox_2d(bbox): 1038 y0, x0 = bbox[0], bbox[1] 1039 w = bbox[3] - x0 1040 h = bbox[2] - y0 1041 return [x0, w, y0, h] 1042 1043 def to_bbox_3d(bbox): 1044 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1045 w = bbox[5] - x0 1046 h = bbox[4] - y0 1047 d = bbox[3] - y0 1048 return [x0, w, y0, h, z0, d] 1049 1050 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1051 masks = [ 1052 { 1053 "segmentation": segmentation == prop.label, 1054 "area": prop.area, 1055 "bbox": to_bbox(prop.bbox), 1056 "crop_box": crop_box, 1057 "seg_id": prop.label, 1058 } for prop in props 1059 ] 1060 return masks 1061 1062 def generate( 1063 self, 1064 center_distance_threshold: float = 0.5, 1065 boundary_distance_threshold: float = 0.5, 1066 foreground_threshold: float = 0.5, 1067 foreground_smoothing: float = 1.0, 1068 distance_smoothing: float = 1.6, 1069 min_size: int = 0, 1070 output_mode: Optional[str] = "binary_mask", 1071 tile_shape: Optional[Tuple[int, int]] = None, 1072 halo: Optional[Tuple[int, int]] = None, 1073 n_threads: Optional[int] = None, 1074 ) -> List[Dict[str, Any]]: 1075 """Generate instance segmentation for the currently initialized image. 1076 1077 Args: 1078 center_distance_threshold: Center distance predictions below this value will be 1079 used to find seeds (intersected with thresholded boundary distance predictions). 1080 By default, set to '0.5'. 
1081 boundary_distance_threshold: Boundary distance predictions below this value will be 1082 used to find seeds (intersected with thresholded center distance predictions). 1083 By default, set to '0.5'. 1084 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1085 By default, set to '0.5'. 1086 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1087 checkerboard artifacts in the prediction. By default, set to '1.0'. 1088 distance_smoothing: Sigma value for smoothing the distance predictions. 1089 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1090 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1091 By default, set to 'binary_mask'. 1092 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1093 This parameter is independent from the tile shape for computing the embeddings. 1094 If not given then post-processing will not be parallelized. 1095 halo: Halo for parallel post-processing. See also `tile_shape`. 1096 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1097 1098 Returns: 1099 The instance segmentation masks. 1100 """ 1101 if not self.is_initialized: 1102 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") 1103 1104 if foreground_smoothing > 0: 1105 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1106 else: 1107 foreground = self._foreground 1108 1109 if tile_shape is None: 1110 segmentation = watershed_from_center_and_boundary_distances( 1111 center_distances=self._center_distances, 1112 boundary_distances=self._boundary_distances, 1113 foreground_map=foreground, 1114 center_distance_threshold=center_distance_threshold, 1115 boundary_distance_threshold=boundary_distance_threshold, 1116 foreground_threshold=foreground_threshold, 1117 distance_smoothing=distance_smoothing, 1118 min_size=min_size, 1119 ) 1120 else: 1121 if halo is None: 1122 raise ValueError("You must pass a value for halo if tile_shape is given.") 1123 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1124 center_distances=self._center_distances, 1125 boundary_distances=self._boundary_distances, 1126 foreground_map=foreground, 1127 center_distance_threshold=center_distance_threshold, 1128 boundary_distance_threshold=boundary_distance_threshold, 1129 foreground_threshold=foreground_threshold, 1130 distance_smoothing=distance_smoothing, 1131 min_size=min_size, 1132 tile_shape=tile_shape, 1133 halo=halo, 1134 n_threads=n_threads, 1135 verbose=False, 1136 ) 1137 1138 if output_mode is not None: 1139 segmentation = self._to_masks(segmentation, output_mode) 1140 return segmentation 1141 1142 def get_state(self) -> Dict[str, Any]: 1143 """Get the initialized state of the instance segmenter. 1144 1145 Returns: 1146 Instance segmentation state. 1147 """ 1148 if not self.is_initialized: 1149 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1150 1151 return { 1152 "foreground": self._foreground, 1153 "center_distances": self._center_distances, 1154 "boundary_distances": self._boundary_distances, 1155 } 1156 1157 def set_state(self, state: Dict[str, Any]) -> None: 1158 """Set the state of the instance segmenter. 
1159 1160 Args: 1161 state: The instance segmentation state 1162 """ 1163 self._foreground = state["foreground"] 1164 self._center_distances = state["center_distances"] 1165 self._boundary_distances = state["boundary_distances"] 1166 self._is_initialized = True 1167 1168 def clear_state(self): 1169 """Clear the state of the instance segmenter. 1170 """ 1171 self._foreground = None 1172 self._center_distances = None 1173 self._boundary_distances = None 1174 self._is_initialized = False 1175 1176 1177class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder): 1178 """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings. 1179 """ 1180 1181 # Apply the decoder in a batched fashion, and then perform the resizing independently per output. 1182 # This is necessary, because the individual tiles may have different tile shapes due to border tiles. 1183 def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes): 1184 batched_embeddings = torch.cat(batched_embeddings) 1185 output = self._decoder._forward_impl(batched_embeddings) 1186 1187 batched_output = [] 1188 for x, input_shape, original_shape in zip(output, input_shapes, original_shapes): 1189 x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0) 1190 batched_output.append(x.cpu().numpy()) 1191 return batched_output 1192 1193 @torch.no_grad() 1194 def initialize( 1195 self, 1196 image: np.ndarray, 1197 image_embeddings: Optional[util.ImageEmbeddings] = None, 1198 i: Optional[int] = None, 1199 tile_shape: Optional[Tuple[int, int]] = None, 1200 halo: Optional[Tuple[int, int]] = None, 1201 verbose: bool = False, 1202 pbar_init: Optional[callable] = None, 1203 pbar_update: Optional[callable] = None, 1204 batch_size: int = 1, 1205 ) -> None: 1206 """Initialize image embeddings and decoder predictions for an image. 1207 1208 Args: 1209 image: The input image, volume or timeseries. 1210 image_embeddings: Optional precomputed image embeddings. 1211 See `util.precompute_image_embeddings` for details. 1212 i: Index for the image data. Required if `image` has three spatial dimensions 1213 or a time dimension and two spatial dimensions. 1214 tile_shape: Shape of the tiles for precomputing image embeddings. 1215 halo: Overlap of the tiles for tiled precomputation of image embeddings. 1216 verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'. 1217 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1218 Can be used together with pbar_update to handle napari progress bar in other thread. 1219 To enables using this function within a threadworker. 1220 pbar_update: Callback to update an external progress bar. 1221 batch_size: The batch size for image embedding computation and segmentation decoder prediction. 
1222 """ 1223 original_size = image.shape[:2] 1224 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 1225 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 1226 ) 1227 tiling = blocking([0, 0], original_size, tile_shape) 1228 1229 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1230 pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder") 1231 1232 foreground = np.zeros(original_size, dtype="float32") 1233 center_distances = np.zeros(original_size, dtype="float32") 1234 boundary_distances = np.zeros(original_size, dtype="float32") 1235 1236 n_tiles = tiling.numberOfBlocks 1237 n_batches = int(np.ceil(n_tiles / batch_size)) 1238 1239 for batch_id in range(n_batches): 1240 tile_start = batch_id * batch_size 1241 tile_stop = min(tile_start + batch_size, n_tiles) 1242 1243 batched_embeddings, input_shapes, original_shapes = [], [], [] 1244 for tile_id in range(tile_start, tile_stop): 1245 # Get the image embeddings from the predictor for this tile. 1246 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id) 1247 1248 batched_embeddings.append(self._predictor.features) 1249 input_shapes.append(tuple(self._predictor.input_size)) 1250 original_shapes.append(tuple(self._predictor.original_size)) 1251 1252 batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes) 1253 1254 for output_id, tile_id in enumerate(range(tile_start, tile_stop)): 1255 output = batched_output[output_id] 1256 assert output.shape[0] == 3 1257 1258 # Set the predictions in the output for this tile. 1259 block = tiling.getBlockWithHalo(tile_id, halo=list(halo)) 1260 local_bb = tuple( 1261 slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end) 1262 ) 1263 inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) 1264 1265 foreground[inner_bb] = output[0][local_bb] 1266 center_distances[inner_bb] = output[1][local_bb] 1267 boundary_distances[inner_bb] = output[2][local_bb] 1268 pbar_update(1) 1269 1270 pbar_close() 1271 1272 # Set the state. 1273 self._foreground = foreground 1274 self._center_distances = center_distances 1275 self._boundary_distances = boundary_distances 1276 self._is_initialized = True 1277 1278 1279def get_amg( 1280 predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs, 1281) -> Union[AMGBase, InstanceSegmentationWithDecoder]: 1282 """Get the automatic mask generator class. 1283 1284 Args: 1285 predictor: The segment anything predictor. 1286 is_tiled: Whether tiled embeddings are used. 1287 decoder: Decoder to predict instacne segmmentation. 1288 kwargs: The keyword arguments for the amg class. 1289 1290 Returns: 1291 The automatic mask generator. 1292 """ 1293 if decoder is None: 1294 segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator 1295 segmenter = segmenter_class(predictor, **kwargs) 1296 else: 1297 segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder 1298 segmenter = segmenter_class(predictor, decoder, **kwargs) 1299 1300 return segmenter
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    with_background: bool,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
            Only supports output_mode=binary_mask.
        with_background: Whether the segmentation has background. If yes this function assures that the largest
            object in the output will be mapped to zero (the background value).
        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.
            By default, set to 'True'.

    Returns:
        The instance segmentation.
    """
    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    # we could also get the shape from the crop box
    shape = next(iter(masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask in masks:
        if mask["area"] < min_object_size:
            continue
        if max_object_size is not None and mask["area"] > max_object_size:
            continue

        this_seg_id = mask.get("seg_id", seg_id)
        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
        seg_id = this_seg_id + 1

    if label_masks:
        segmentation = label(segmentation).astype(segmentation.dtype)

    seg_ids, sizes = np.unique(segmentation, return_counts=True)

    # In some cases objects may be smaller than previously calculated,
    # since they are covered by other objects. We ensure these also get
    # filtered out here.
    filter_ids = seg_ids[sizes < min_object_size]

    # If we run segmentation with background we also map the largest segment
    # (the most likely background object) to zero. This is often zero already,
    # but it does not hurt to reset that to zero either.
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    segmentation[np.isin(segmentation, filter_ids)] = 0
    segmentation = relabel_sequential(segmentation)[0]

    return segmentation
Convert the output of the automatic mask generation to an instance segmentation.
Arguments:
- masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
- with_background: Whether the segmentation has background. If yes this function assures that the largest object in the output will be mapped to zero (the background value).
- min_object_size: The minimal size of an object in pixels. By default, set to '0'.
- max_object_size: The maximal size of an object in pixels.
- label_masks: Whether to apply connected components to the result before removing small objects. By default, set to 'True'.
Returns:
The instance segmentation.
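A minimal usage sketch for this conversion, assuming a 2D image already loaded as a numpy array and a vanilla SAM model fetched via `micro_sam.util.get_sam_model`; the threshold and size values are illustrative only.

```python
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

predictor = util.get_sam_model(model_type="vit_b")

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)                        # expensive: embeddings and mask proposals (image: 2D numpy array, assumption)
masks = amg.generate(pred_iou_thresh=0.88)   # cheap post-processing, returns binary_mask records

# Merge the mask records into a single label image, treating the largest object as background.
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=100)
```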
class AMGBase(ABC):
    """Base class for the automatic mask generators.
    """
    def __init__(self):
        # the state that has to be computed by the 'initialize' method of the child classes
        self._is_initialized = False
        self._crop_list = None
        self._crop_boxes = None
        self._original_size = None
Base class for the automatic mask generators.
130 @property 131 def is_initialized(self): 132 """Whether the mask generator has already been initialized. 133 """ 134 return self._is_initialized
Whether the mask generator has already been initialized.
136 @property 137 def crop_list(self): 138 """The list of mask data after initialization. 139 """ 140 return self._crop_list
The list of mask data after initialization.
142 @property 143 def crop_boxes(self): 144 """The list of crop boxes. 145 """ 146 return self._crop_boxes
The list of crop boxes.
148 @property 149 def original_size(self): 150 """The original image size. 151 """ 152 return self._original_size
The original image size.
309 def get_state(self) -> Dict[str, Any]: 310 """Get the initialized state of the mask generator. 311 312 Returns: 313 State of the mask generator. 314 """ 315 if not self.is_initialized: 316 raise RuntimeError("The state has not been computed yet. Call initialize first.") 317 318 return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
Get the initialized state of the mask generator.
Returns:
State of the mask generator.
320 def set_state(self, state: Dict[str, Any]) -> None: 321 """Set the state of the mask generator. 322 323 Args: 324 state: The state of the mask generator, e.g. from serialized state. 325 """ 326 self._crop_list = state["crop_list"] 327 self._crop_boxes = state["crop_boxes"] 328 self._original_size = state["original_size"] 329 self._is_initialized = True
Set the state of the mask generator.
Arguments:
- state: The state of the mask generator, e.g. from serialized state.
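Because `get_state` and `set_state` round-trip the full initialization, the expensive `initialize` call only has to run once per image. A minimal sketch, assuming a `SamPredictor` called `predictor` and a 2D `image` are already available, and using `AutomaticMaskGenerator` as the concrete subclass:

```python
# Compute the expensive initialization once ...
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)
state = amg.get_state()  # {"crop_list": ..., "crop_boxes": ..., "original_size": ...}

# ... and reuse it in another generator instance without recomputation.
amg2 = AutomaticMaskGenerator(predictor)
amg2.set_state(state)
masks = amg2.generate(pred_iou_thresh=0.88)
```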
340class AutomaticMaskGenerator(AMGBase): 341 """Generates an instance segmentation without prompts, using a point grid. 342 343 This class implements the same logic as 344 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 345 It decouples the computationally expensive steps of generating masks from the cheap post-processing operation 346 to filter these masks to enable grid search and interactively changing the post-processing. 347 348 Use this class as follows: 349 ```python 350 amg = AutomaticMaskGenerator(predictor) 351 amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. 352 masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters 353 ``` 354 355 Args: 356 predictor: The segment anything predictor. 357 points_per_side: The number of points to be sampled along one side of the image. 358 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 359 points_per_batch: The number of points run simultaneously by the model. 360 Higher numbers may be faster but use more GPU memory. 361 By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons). 362 crop_n_layers: If >0, the mask prediction will be run again on crops of the image. 363 By default, set to '0'. 364 crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'. 365 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. 366 By default, set to '1'. 367 point_grids: A list over explicit grids of points used for sampling masks. 368 Normalized to [0, 1] with respect to the image coordinate system. 369 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 370 By default, set to '1.0'. 
371 """ 372 def __init__( 373 self, 374 predictor: SamPredictor, 375 points_per_side: Optional[int] = 32, 376 points_per_batch: Optional[int] = None, 377 crop_n_layers: int = 0, 378 crop_overlap_ratio: float = 512 / 1500, 379 crop_n_points_downscale_factor: int = 1, 380 point_grids: Optional[List[np.ndarray]] = None, 381 stability_score_offset: float = 1.0, 382 ): 383 super().__init__() 384 385 if points_per_side is not None: 386 self.point_grids = amg_utils.build_all_layer_point_grids( 387 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 388 ) 389 elif point_grids is not None: 390 self.point_grids = point_grids 391 else: 392 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 393 394 self._predictor = predictor 395 self._points_per_side = points_per_side 396 397 # we set the points per batch to 16 for mps for performance reasons 398 # and otherwise keep them at the default of 64 399 if points_per_batch is None: 400 points_per_batch = 16 if str(predictor.device) == "mps" else 64 401 self._points_per_batch = points_per_batch 402 403 self._crop_n_layers = crop_n_layers 404 self._crop_overlap_ratio = crop_overlap_ratio 405 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 406 self._stability_score_offset = stability_score_offset 407 408 def _process_batch(self, points, im_size, crop_box, original_size): 409 # run model on this batch 410 transformed_points = self._predictor.transform.apply_coords(points, im_size) 411 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float) 412 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 413 masks, iou_preds, _ = self._predictor.predict_torch( 414 point_coords=in_points[:, None, :], 415 point_labels=in_labels[:, None], 416 multimask_output=True, 417 return_logits=True, 418 ) 419 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points) 420 del masks 421 return data 422 423 def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None): 424 # Crop the image and calculate embeddings. 425 x0, y0, x1, y1 = crop_box 426 cropped_im = image[y0:y1, x0:x1, :] 427 cropped_im_size = cropped_im.shape[:2] 428 429 if not precomputed_embeddings: 430 self._predictor.set_image(cropped_im) 431 432 # Get the points for this crop. 433 points_scale = np.array(cropped_im_size)[None, ::-1] 434 points_for_image = self.point_grids[crop_layer_idx] * points_scale 435 436 # Generate masks for this crop in batches. 
437 data = amg_utils.MaskData() 438 n_batches = len(points_for_image) // self._points_per_batch +\ 439 int(len(points_for_image) % self._points_per_batch != 0) 440 if pbar_init is not None: 441 pbar_init(n_batches, "Predict masks for point grid prompts") 442 443 for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image): 444 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size) 445 data.cat(batch_data) 446 del batch_data 447 if pbar_update is not None: 448 pbar_update(1) 449 450 if not precomputed_embeddings: 451 self._predictor.reset_image() 452 453 return data 454 455 @torch.no_grad() 456 def initialize( 457 self, 458 image: np.ndarray, 459 image_embeddings: Optional[util.ImageEmbeddings] = None, 460 i: Optional[int] = None, 461 verbose: bool = False, 462 pbar_init: Optional[callable] = None, 463 pbar_update: Optional[callable] = None, 464 ) -> None: 465 """Initialize image embeddings and masks for an image. 466 467 Args: 468 image: The input image, volume or timeseries. 469 image_embeddings: Optional precomputed image embeddings. 470 See `util.precompute_image_embeddings` for details. 471 i: Index for the image data. Required if `image` has three spatial dimensions 472 or a time dimension and two spatial dimensions. 473 verbose: Whether to print computation progress. By default, set to 'False'. 474 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 475 Can be used together with pbar_update to handle napari progress bar in other thread. 476 To enables using this function within a threadworker. 477 pbar_update: Callback to update an external progress bar. 478 """ 479 original_size = image.shape[:2] 480 self._original_size = original_size 481 482 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 483 original_size, self._crop_n_layers, self._crop_overlap_ratio 484 ) 485 486 # We can set fixed image embeddings if we only have a single crop box (the default setting). 487 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 488 if len(crop_boxes) == 1: 489 if image_embeddings is None: 490 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 491 util.set_precomputed(self._predictor, image_embeddings, i=i) 492 precomputed_embeddings = True 493 else: 494 precomputed_embeddings = False 495 496 # we need to cast to the image representation that is compatible with SAM 497 image = util._to_image(image) 498 499 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 500 501 crop_list = [] 502 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 503 crop_data = self._process_crop( 504 image, crop_box, layer_idx, 505 precomputed_embeddings=precomputed_embeddings, 506 pbar_init=pbar_init, pbar_update=pbar_update, 507 ) 508 crop_list.append(crop_data) 509 pbar_close() 510 511 self._is_initialized = True 512 self._crop_list = crop_list 513 self._crop_boxes = crop_boxes 514 515 @torch.no_grad() 516 def generate( 517 self, 518 pred_iou_thresh: float = 0.88, 519 stability_score_thresh: float = 0.95, 520 box_nms_thresh: float = 0.7, 521 crop_nms_thresh: float = 0.7, 522 min_mask_region_area: int = 0, 523 output_mode: str = "binary_mask", 524 ) -> List[Dict[str, Any]]: 525 """Generate instance segmentation for the currently initialized image. 526 527 Args: 528 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 529 By default, set to '0.88'. 
530 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 531 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 532 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 533 By default, set to '0.7'. 534 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 535 By default, set to '0.7'. 536 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 537 output_mode: The form masks are returned in. By default, set to 'binary_mask'. 538 539 Returns: 540 The instance segmentation masks. 541 """ 542 if not self.is_initialized: 543 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 544 545 data = amg_utils.MaskData() 546 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 547 crop_data = self._postprocess_batch( 548 data=deepcopy(data_), 549 crop_box=crop_box, original_size=self.original_size, 550 pred_iou_thresh=pred_iou_thresh, 551 stability_score_thresh=stability_score_thresh, 552 box_nms_thresh=box_nms_thresh 553 ) 554 data.cat(crop_data) 555 556 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 557 # Prefer masks from smaller crops 558 scores = 1 / box_area(data["crop_boxes"]) 559 scores = scores.to(data["boxes"].device) 560 keep_by_nms = batched_nms( 561 data["boxes"].float(), 562 scores, 563 torch.zeros_like(data["boxes"][:, 0]), # categories 564 iou_threshold=crop_nms_thresh, 565 ) 566 data.filter(keep_by_nms) 567 568 data.to_numpy() 569 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 570 return masks
Generates an instance segmentation without prompts, using a point grid.
This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive step of generating masks from the cheap post-processing used to filter them, which enables grid search and interactively changing the post-processing.
Use this class as follows:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image) # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters
Arguments:
- predictor: The segment anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
- crop_n_layers: If >0, the mask prediction will be run again on crops of the image. By default, set to '0'.
- crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
- crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. By default, set to '1'.
- point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
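Since all heavy computation happens in `initialize`, `generate` can be called repeatedly with different post-processing settings, e.g. for a grid search. A minimal sketch, assuming `predictor` (a `SamPredictor`) and a 2D `image` are already available:

```python
amg = AutomaticMaskGenerator(predictor, points_per_side=32)
amg.initialize(image, verbose=True)  # Expensive: embeddings and mask prediction.

# Cheap: filter the precomputed masks with different thresholds.
for thresh in (0.8, 0.88, 0.95):
    masks = amg.generate(pred_iou_thresh=thresh)
    print(f"pred_iou_thresh={thresh}: {len(masks)} masks")
```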
372 def __init__( 373 self, 374 predictor: SamPredictor, 375 points_per_side: Optional[int] = 32, 376 points_per_batch: Optional[int] = None, 377 crop_n_layers: int = 0, 378 crop_overlap_ratio: float = 512 / 1500, 379 crop_n_points_downscale_factor: int = 1, 380 point_grids: Optional[List[np.ndarray]] = None, 381 stability_score_offset: float = 1.0, 382 ): 383 super().__init__() 384 385 if points_per_side is not None: 386 self.point_grids = amg_utils.build_all_layer_point_grids( 387 points_per_side, crop_n_layers, crop_n_points_downscale_factor, 388 ) 389 elif point_grids is not None: 390 self.point_grids = point_grids 391 else: 392 raise ValueError("Can't have both points_per_side and point_grid be None or not None.") 393 394 self._predictor = predictor 395 self._points_per_side = points_per_side 396 397 # we set the points per batch to 16 for mps for performance reasons 398 # and otherwise keep them at the default of 64 399 if points_per_batch is None: 400 points_per_batch = 16 if str(predictor.device) == "mps" else 64 401 self._points_per_batch = points_per_batch 402 403 self._crop_n_layers = crop_n_layers 404 self._crop_overlap_ratio = crop_overlap_ratio 405 self._crop_n_points_downscale_factor = crop_n_points_downscale_factor 406 self._stability_score_offset = stability_score_offset
455 @torch.no_grad() 456 def initialize( 457 self, 458 image: np.ndarray, 459 image_embeddings: Optional[util.ImageEmbeddings] = None, 460 i: Optional[int] = None, 461 verbose: bool = False, 462 pbar_init: Optional[callable] = None, 463 pbar_update: Optional[callable] = None, 464 ) -> None: 465 """Initialize image embeddings and masks for an image. 466 467 Args: 468 image: The input image, volume or timeseries. 469 image_embeddings: Optional precomputed image embeddings. 470 See `util.precompute_image_embeddings` for details. 471 i: Index for the image data. Required if `image` has three spatial dimensions 472 or a time dimension and two spatial dimensions. 473 verbose: Whether to print computation progress. By default, set to 'False'. 474 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 475 Can be used together with pbar_update to handle napari progress bar in other thread. 476 To enables using this function within a threadworker. 477 pbar_update: Callback to update an external progress bar. 478 """ 479 original_size = image.shape[:2] 480 self._original_size = original_size 481 482 crop_boxes, layer_idxs = amg_utils.generate_crop_boxes( 483 original_size, self._crop_n_layers, self._crop_overlap_ratio 484 ) 485 486 # We can set fixed image embeddings if we only have a single crop box (the default setting). 487 # Otherwise we have to recompute the embeddings for each crop and can't precompute. 488 if len(crop_boxes) == 1: 489 if image_embeddings is None: 490 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 491 util.set_precomputed(self._predictor, image_embeddings, i=i) 492 precomputed_embeddings = True 493 else: 494 precomputed_embeddings = False 495 496 # we need to cast to the image representation that is compatible with SAM 497 image = util._to_image(image) 498 499 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 500 501 crop_list = [] 502 for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 503 crop_data = self._process_crop( 504 image, crop_box, layer_idx, 505 precomputed_embeddings=precomputed_embeddings, 506 pbar_init=pbar_init, pbar_update=pbar_update, 507 ) 508 crop_list.append(crop_data) 509 pbar_close() 510 511 self._is_initialized = True 512 self._crop_list = crop_list 513 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to print computation progress. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
515 @torch.no_grad() 516 def generate( 517 self, 518 pred_iou_thresh: float = 0.88, 519 stability_score_thresh: float = 0.95, 520 box_nms_thresh: float = 0.7, 521 crop_nms_thresh: float = 0.7, 522 min_mask_region_area: int = 0, 523 output_mode: str = "binary_mask", 524 ) -> List[Dict[str, Any]]: 525 """Generate instance segmentation for the currently initialized image. 526 527 Args: 528 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 529 By default, set to '0.88'. 530 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask 531 under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'. 532 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. 533 By default, set to '0.7'. 534 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. 535 By default, set to '0.7'. 536 min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'. 537 output_mode: The form masks are returned in. By default, set to 'binary_mask'. 538 539 Returns: 540 The instance segmentation masks. 541 """ 542 if not self.is_initialized: 543 raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") 544 545 data = amg_utils.MaskData() 546 for data_, crop_box in zip(self.crop_list, self.crop_boxes): 547 crop_data = self._postprocess_batch( 548 data=deepcopy(data_), 549 crop_box=crop_box, original_size=self.original_size, 550 pred_iou_thresh=pred_iou_thresh, 551 stability_score_thresh=stability_score_thresh, 552 box_nms_thresh=box_nms_thresh 553 ) 554 data.cat(crop_data) 555 556 if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0: 557 # Prefer masks from smaller crops 558 scores = 1 / box_area(data["crop_boxes"]) 559 scores = scores.to(data["boxes"].device) 560 keep_by_nms = batched_nms( 561 data["boxes"].float(), 562 scores, 563 torch.zeros_like(data["boxes"][:, 0]), # categories 564 iou_threshold=crop_nms_thresh, 565 ) 566 data.filter(keep_by_nms) 567 568 data.to_numpy() 569 masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode) 570 return masks
Generate instance segmentation for the currently initialized image.
Arguments:
- pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. By default, set to '0.88'.
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
- box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. By default, set to '0.7'.
- crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. By default, set to '0.7'.
- min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
- output_mode: The form masks are returned in. By default, set to 'binary_mask'.
Returns:
The instance segmentation masks.
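With the default `output_mode='binary_mask'`, the returned mask records can be turned into a single label image with `mask_data_to_segmentation` from this module. A short sketch, assuming an initialized generator `amg`:

```python
from micro_sam.instance_segmentation import mask_data_to_segmentation

masks = amg.generate(pred_iou_thresh=0.88, min_mask_region_area=25)
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
print(segmentation.shape, int(segmentation.max()))
```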
600class TiledAutomaticMaskGenerator(AutomaticMaskGenerator): 601 """Generates an instance segmentation without prompts, using a point grid. 602 603 Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings. 604 605 Args: 606 predictor: The Segment Anything predictor. 607 points_per_side: The number of points to be sampled along one side of the image. 608 If None, `point_grids` must provide explicit point sampling. By default, set to '32'. 609 points_per_batch: The number of points run simultaneously by the model. 610 Higher numbers may be faster but use more GPU memory. By default, set to '64'. 611 point_grids: A list over explicit grids of points used for sampling masks. 612 Normalized to [0, 1] with respect to the image coordinate system. 613 stability_score_offset: The amount to shift the cutoff when calculating the stability score. 614 By default, set to '1.0'. 615 """ 616 617 # We only expose the arguments that make sense for the tiled mask generator. 618 # Anything related to crops doesn't make sense, because we re-use that functionality 619 # for tiling, so these parameters wouldn't have any effect. 620 def __init__( 621 self, 622 predictor: SamPredictor, 623 points_per_side: Optional[int] = 32, 624 points_per_batch: int = 64, 625 point_grids: Optional[List[np.ndarray]] = None, 626 stability_score_offset: float = 1.0, 627 ) -> None: 628 super().__init__( 629 predictor=predictor, 630 points_per_side=points_per_side, 631 points_per_batch=points_per_batch, 632 point_grids=point_grids, 633 stability_score_offset=stability_score_offset, 634 ) 635 636 @torch.no_grad() 637 def initialize( 638 self, 639 image: np.ndarray, 640 image_embeddings: Optional[util.ImageEmbeddings] = None, 641 i: Optional[int] = None, 642 tile_shape: Optional[Tuple[int, int]] = None, 643 halo: Optional[Tuple[int, int]] = None, 644 verbose: bool = False, 645 pbar_init: Optional[callable] = None, 646 pbar_update: Optional[callable] = None, 647 batch_size: int = 1, 648 ) -> None: 649 """Initialize image embeddings and masks for an image. 650 651 Args: 652 image: The input image, volume or timeseries. 653 image_embeddings: Optional precomputed image embeddings. 654 See `util.precompute_image_embeddings` for details. 655 i: Index for the image data. Required if `image` has three spatial dimensions 656 or a time dimension and two spatial dimensions. 657 tile_shape: The tile shape for embedding prediction. 658 halo: The overlap of between tiles. 659 verbose: Whether to print computation progress. By default, set to 'False'. 660 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 661 Can be used together with pbar_update to handle napari progress bar in other thread. 662 To enables using this function within a threadworker. 663 pbar_update: Callback to update an external progress bar. 664 batch_size: The batch size for image embedding prediction. By default, set to '1'. 665 """ 666 original_size = image.shape[:2] 667 self._original_size = original_size 668 669 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 670 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 671 ) 672 673 tiling = blocking([0, 0], original_size, tile_shape) 674 n_tiles = tiling.numberOfBlocks 675 676 # The crop box is always the full local tile. 
677 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 678 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 679 680 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 681 pbar_init(n_tiles, "Compute masks for tile") 682 683 # We need to cast to the image representation that is compatible with SAM. 684 image = util._to_image(image) 685 686 mask_data = [] 687 for tile_id in range(n_tiles): 688 # set the pre-computed embeddings for this tile 689 features = image_embeddings["features"][str(tile_id)] 690 tile_embeddings = { 691 "features": features, 692 "input_size": features.attrs["input_size"], 693 "original_size": features.attrs["original_size"], 694 } 695 util.set_precomputed(self._predictor, tile_embeddings, i) 696 697 # compute the mask data for this tile and append it 698 this_mask_data = self._process_crop( 699 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 700 ) 701 mask_data.append(this_mask_data) 702 pbar_update(1) 703 pbar_close() 704 705 # set the initialized data 706 self._is_initialized = True 707 self._crop_list = mask_data 708 self._crop_boxes = crop_boxes
Generates an instance segmentation without prompts, using a point grid.
Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
Arguments:
- predictor: The Segment Anything predictor.
- points_per_side: The number of points to be sampled along one side of the image. If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
- points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, set to '64'.
- point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
- stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
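For images that are too large for a single embedding, the tile shape and halo are passed to `initialize`. A minimal sketch, assuming `predictor` and a large 2D `image` are already available; the tile sizes below are illustrative only:

```python
amg = TiledAutomaticMaskGenerator(predictor)
amg.initialize(image, tile_shape=(512, 512), halo=(64, 64), verbose=True)
masks = amg.generate(pred_iou_thresh=0.88)
```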
620 def __init__( 621 self, 622 predictor: SamPredictor, 623 points_per_side: Optional[int] = 32, 624 points_per_batch: int = 64, 625 point_grids: Optional[List[np.ndarray]] = None, 626 stability_score_offset: float = 1.0, 627 ) -> None: 628 super().__init__( 629 predictor=predictor, 630 points_per_side=points_per_side, 631 points_per_batch=points_per_batch, 632 point_grids=point_grids, 633 stability_score_offset=stability_score_offset, 634 )
636 @torch.no_grad() 637 def initialize( 638 self, 639 image: np.ndarray, 640 image_embeddings: Optional[util.ImageEmbeddings] = None, 641 i: Optional[int] = None, 642 tile_shape: Optional[Tuple[int, int]] = None, 643 halo: Optional[Tuple[int, int]] = None, 644 verbose: bool = False, 645 pbar_init: Optional[callable] = None, 646 pbar_update: Optional[callable] = None, 647 batch_size: int = 1, 648 ) -> None: 649 """Initialize image embeddings and masks for an image. 650 651 Args: 652 image: The input image, volume or timeseries. 653 image_embeddings: Optional precomputed image embeddings. 654 See `util.precompute_image_embeddings` for details. 655 i: Index for the image data. Required if `image` has three spatial dimensions 656 or a time dimension and two spatial dimensions. 657 tile_shape: The tile shape for embedding prediction. 658 halo: The overlap of between tiles. 659 verbose: Whether to print computation progress. By default, set to 'False'. 660 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 661 Can be used together with pbar_update to handle napari progress bar in other thread. 662 To enables using this function within a threadworker. 663 pbar_update: Callback to update an external progress bar. 664 batch_size: The batch size for image embedding prediction. By default, set to '1'. 665 """ 666 original_size = image.shape[:2] 667 self._original_size = original_size 668 669 image_embeddings, tile_shape, halo = _process_tiled_embeddings( 670 self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size 671 ) 672 673 tiling = blocking([0, 0], original_size, tile_shape) 674 n_tiles = tiling.numberOfBlocks 675 676 # The crop box is always the full local tile. 677 tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] 678 crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles] 679 680 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 681 pbar_init(n_tiles, "Compute masks for tile") 682 683 # We need to cast to the image representation that is compatible with SAM. 684 image = util._to_image(image) 685 686 mask_data = [] 687 for tile_id in range(n_tiles): 688 # set the pre-computed embeddings for this tile 689 features = image_embeddings["features"][str(tile_id)] 690 tile_embeddings = { 691 "features": features, 692 "input_size": features.attrs["input_size"], 693 "original_size": features.attrs["original_size"], 694 } 695 util.set_precomputed(self._predictor, tile_embeddings, i) 696 697 # compute the mask data for this tile and append it 698 this_mask_data = self._process_crop( 699 image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True 700 ) 701 mask_data.append(this_mask_data) 702 pbar_update(1) 703 pbar_close() 704 705 # set the initialized data 706 self._is_initialized = True 707 self._crop_list = mask_data 708 self._crop_boxes = crop_boxes
Initialize image embeddings and masks for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- verbose: Whether to print computation progress. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding prediction. By default, set to '1'.
class DecoderAdapter(torch.nn.Module):
    """Adapter to contain the UNETR decoder in a single module.

    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
    """
    def __init__(self, unetr: torch.nn.Module):
        super().__init__()

        self.base = unetr.base
        self.out_conv = unetr.out_conv
        self.deconv_out = unetr.deconv_out
        self.decoder_head = unetr.decoder_head
        self.final_activation = unetr.final_activation
        self.postprocess_masks = unetr.postprocess_masks

        self.decoder = unetr.decoder
        self.deconv1 = unetr.deconv1
        self.deconv2 = unetr.deconv2
        self.deconv3 = unetr.deconv3
        self.deconv4 = unetr.deconv4

    def _forward_impl(self, input_):
        z12 = input_

        z9 = self.deconv1(z12)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return x

    def forward(self, input_, input_shape, original_shape):
        x = self._forward_impl(input_)
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
Adapter to contain the UNETR decoder in a single module.
To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- set_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- mtia
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_post_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_pre_hook
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def get_unetr(
    image_encoder: torch.nn.Module,
    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
    device: Optional[Union[str, torch.device]] = None,
    out_channels: int = 3,
    flexible_load_checkpoint: bool = False,
) -> torch.nn.Module:
    """Get the UNETR model for automatic instance segmentation.

    Args:
        image_encoder: The image encoder of the SAM model.
            This is used as the encoder by the UNETR too.
        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
        device: The device. By default, automatically chooses the best available device.
        out_channels: The number of output channels. By default, set to '3'.
        flexible_load_checkpoint: Whether to allow reinitialization of parameters
            which could not be found in the provided decoder state. By default, set to 'False'.

    Returns:
        The UNETR model.
    """
    device = util.get_device(device)

    if decoder_state is None:
        use_conv_transpose = False  # By default, we use interpolation for upsampling.
    else:
        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions:
        # - We look for parameter names that contain '.block.' within the "decoder.samplers" submodules.
        #   This naming convention indicates that transposed convolutions are used,
        #   wrapped inside a custom block module.
        # - Otherwise '.conv.' appears, which indicates a standard `Conv2d` applied after interpolation
        #   for upsampling.
        use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers"))

    unetr = UNETR(
        backbone="sam",
        encoder=image_encoder,
        out_channels=out_channels,
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False,
        resize_input=True,
        use_conv_transpose=use_conv_transpose,
    )

    if decoder_state is not None:
        unetr_state_dict = unetr.state_dict()
        for k, v in unetr_state_dict.items():
            if not k.startswith("encoder"):
                if flexible_load_checkpoint:  # Allow reinitialization of parameters that are not found.
                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
                        unetr_state_dict[k] = decoder_state[k]
                    else:  # Otherwise, keep the freshly initialized parameter.
                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
                        unetr_state_dict[k] = v
                else:  # Be strict about finding the parameter in the decoder state.
                    if k not in decoder_state:
                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
                    unetr_state_dict[k] = decoder_state[k]

        unetr.load_state_dict(unetr_state_dict)

    unetr.to(device)
    return unetr
Get UNETR model for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
- decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
- device: The device. By default, automatically chooses the best available device.
- out_channels: The number of output channels. By default, set to '3'.
- flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state. By default, set to 'False'.
Returns:
The UNETR model.
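A short sketch of constructing the UNETR on top of a SAM image encoder. Here `predictor` is an already loaded `SamPredictor`, and `decoder_state` stands in for pretrained decoder weights (e.g. the `'decoder_state'` entry of a combined checkpoint); both are assumptions for illustration:

```python
# Build the UNETR that shares its encoder with the SAM model.
unetr = get_unetr(
    image_encoder=predictor.model.image_encoder,
    decoder_state=decoder_state,  # OrderedDict of pretrained decoder weights (assumed available).
    out_channels=3,
)
```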
def get_decoder(
    image_encoder: torch.nn.Module,
    decoder_state: OrderedDict[str, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> DecoderAdapter:
    """Get the decoder to predict outputs for automatic instance segmentation.

    Args:
        image_encoder: The image encoder of the SAM model.
        decoder_state: State to initialize the weights of the UNETR decoder.
        device: The device. By default, automatically chooses the best available device.

    Returns:
        The decoder for instance segmentation.
    """
    unetr = get_unetr(image_encoder, decoder_state, device)
    return DecoderAdapter(unetr)
Get the decoder to predict outputs for automatic instance segmentation.
Arguments:
- image_encoder: The image encoder of the SAM model.
- decoder_state: State to initialize the weights of the UNETR decoder.
- device: The device. By default, automatically chooses the best available device.
Returns:
The decoder for instance segmentation.
def get_predictor_and_decoder(
    model_type: str,
    checkpoint_path: Union[str, os.PathLike],
    device: Optional[Union[str, torch.device]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> Tuple[SamPredictor, DecoderAdapter]:
    """Load the SAM model (predictor) and instance segmentation decoder.

    This requires a checkpoint that contains the state for both predictor
    and decoder.

    Args:
        model_type: The type of the image encoder used in the SAM model.
        checkpoint_path: Path to the checkpoint from which to load the data.
        device: The device. By default, automatically chooses the best available device.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.

    Returns:
        The SAM predictor.
        The decoder for instance segmentation.
    """
    device = util.get_device(device)
    predictor, state = util.get_sam_model(
        model_type=model_type,
        checkpoint_path=checkpoint_path,
        device=device,
        return_state=True,
        peft_kwargs=peft_kwargs,
    )

    if "decoder_state" not in state:
        raise ValueError(
            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
        )

    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
    return predictor, decoder
Load the SAM model (predictor) and instance segmentation decoder.
This requires a checkpoint that contains the state for both predictor and decoder.
Arguments:
- model_type: The type of the image encoder used in the SAM model.
- checkpoint_path: Path to the checkpoint from which to load the data.
- device: The device. By default, automatically chooses the best available device.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:
The SAM predictor. The decoder for instance segmentation.
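A minimal sketch of loading a combined checkpoint and wiring the result into `InstanceSegmentationWithDecoder`; the model type and checkpoint path below are placeholders:

```python
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b",                        # Type of the image encoder in the checkpoint (placeholder).
    checkpoint_path="path/to/checkpoint.pt",   # Must contain a 'decoder_state' entry (placeholder path).
)
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
```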
941class InstanceSegmentationWithDecoder: 942 """Generates an instance segmentation without prompts, using a decoder. 943 944 Implements the same interface as `AutomaticMaskGenerator`. 945 946 Use this class as follows: 947 ```python 948 segmenter = InstanceSegmentationWithDecoder(predictor, decoder) 949 segmenter.initialize(image) # Predict the image embeddings and decoder outputs. 950 masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. 951 ``` 952 953 Args: 954 predictor: The segment anything predictor. 955 decoder: The decoder to predict intermediate representations 956 for instance segmentation. 957 """ 958 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 959 self._predictor = predictor 960 self._decoder = decoder 961 962 # The decoder outputs. 963 self._foreground = None 964 self._center_distances = None 965 self._boundary_distances = None 966 967 self._is_initialized = False 968 969 @property 970 def is_initialized(self): 971 """Whether the mask generator has already been initialized. 972 """ 973 return self._is_initialized 974 975 @torch.no_grad() 976 def initialize( 977 self, 978 image: np.ndarray, 979 image_embeddings: Optional[util.ImageEmbeddings] = None, 980 i: Optional[int] = None, 981 verbose: bool = False, 982 pbar_init: Optional[callable] = None, 983 pbar_update: Optional[callable] = None, 984 ) -> None: 985 """Initialize image embeddings and decoder predictions for an image. 986 987 Args: 988 image: The input image, volume or timeseries. 989 image_embeddings: Optional precomputed image embeddings. 990 See `util.precompute_image_embeddings` for details. 991 i: Index for the image data. Required if `image` has three spatial dimensions 992 or a time dimension and two spatial dimensions. 993 verbose: Whether to be verbose. By default, set to 'False'. 994 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 995 Can be used together with pbar_update to handle napari progress bar in other thread. 996 To enables using this function within a threadworker. 997 pbar_update: Callback to update an external progress bar. 998 """ 999 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1000 pbar_init(1, "Initialize instance segmentation with decoder") 1001 1002 if image_embeddings is None: 1003 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 1004 1005 # Get the image embeddings from the predictor. 1006 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 1007 embeddings = self._predictor.features 1008 input_shape = tuple(self._predictor.input_size) 1009 original_shape = tuple(self._predictor.original_size) 1010 1011 # Run prediction with the UNETR decoder. 1012 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1013 assert output.shape[0] == 3, f"{output.shape}" 1014 pbar_update(1) 1015 pbar_close() 1016 1017 # Set the state. 
1018 self._foreground = output[0] 1019 self._center_distances = output[1] 1020 self._boundary_distances = output[2] 1021 self._is_initialized = True 1022 1023 def _to_masks(self, segmentation, output_mode): 1024 if output_mode != "binary_mask": 1025 raise NotImplementedError 1026 1027 props = regionprops(segmentation) 1028 ndim = segmentation.ndim 1029 assert ndim in (2, 3) 1030 1031 shape = segmentation.shape 1032 if ndim == 2: 1033 crop_box = [0, shape[1], 0, shape[0]] 1034 else: 1035 crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] 1036 1037 # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] 1038 def to_bbox_2d(bbox): 1039 y0, x0 = bbox[0], bbox[1] 1040 w = bbox[3] - x0 1041 h = bbox[2] - y0 1042 return [x0, w, y0, h] 1043 1044 def to_bbox_3d(bbox): 1045 z0, y0, x0 = bbox[0], bbox[1], bbox[2] 1046 w = bbox[5] - x0 1047 h = bbox[4] - y0 1048 d = bbox[3] - y0 1049 return [x0, w, y0, h, z0, d] 1050 1051 to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d 1052 masks = [ 1053 { 1054 "segmentation": segmentation == prop.label, 1055 "area": prop.area, 1056 "bbox": to_bbox(prop.bbox), 1057 "crop_box": crop_box, 1058 "seg_id": prop.label, 1059 } for prop in props 1060 ] 1061 return masks 1062 1063 def generate( 1064 self, 1065 center_distance_threshold: float = 0.5, 1066 boundary_distance_threshold: float = 0.5, 1067 foreground_threshold: float = 0.5, 1068 foreground_smoothing: float = 1.0, 1069 distance_smoothing: float = 1.6, 1070 min_size: int = 0, 1071 output_mode: Optional[str] = "binary_mask", 1072 tile_shape: Optional[Tuple[int, int]] = None, 1073 halo: Optional[Tuple[int, int]] = None, 1074 n_threads: Optional[int] = None, 1075 ) -> List[Dict[str, Any]]: 1076 """Generate instance segmentation for the currently initialized image. 1077 1078 Args: 1079 center_distance_threshold: Center distance predictions below this value will be 1080 used to find seeds (intersected with thresholded boundary distance predictions). 1081 By default, set to '0.5'. 1082 boundary_distance_threshold: Boundary distance predictions below this value will be 1083 used to find seeds (intersected with thresholded center distance predictions). 1084 By default, set to '0.5'. 1085 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1086 By default, set to '0.5'. 1087 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1088 checkerboard artifacts in the prediction. By default, set to '1.0'. 1089 distance_smoothing: Sigma value for smoothing the distance predictions. 1090 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1091 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1092 By default, set to 'binary_mask'. 1093 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1094 This parameter is independent from the tile shape for computing the embeddings. 1095 If not given then post-processing will not be parallelized. 1096 halo: Halo for parallel post-processing. See also `tile_shape`. 1097 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1098 1099 Returns: 1100 The instance segmentation masks. 1101 """ 1102 if not self.is_initialized: 1103 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1104 1105 if foreground_smoothing > 0: 1106 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1107 else: 1108 foreground = self._foreground 1109 1110 if tile_shape is None: 1111 segmentation = watershed_from_center_and_boundary_distances( 1112 center_distances=self._center_distances, 1113 boundary_distances=self._boundary_distances, 1114 foreground_map=foreground, 1115 center_distance_threshold=center_distance_threshold, 1116 boundary_distance_threshold=boundary_distance_threshold, 1117 foreground_threshold=foreground_threshold, 1118 distance_smoothing=distance_smoothing, 1119 min_size=min_size, 1120 ) 1121 else: 1122 if halo is None: 1123 raise ValueError("You must pass a value for halo if tile_shape is given.") 1124 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1125 center_distances=self._center_distances, 1126 boundary_distances=self._boundary_distances, 1127 foreground_map=foreground, 1128 center_distance_threshold=center_distance_threshold, 1129 boundary_distance_threshold=boundary_distance_threshold, 1130 foreground_threshold=foreground_threshold, 1131 distance_smoothing=distance_smoothing, 1132 min_size=min_size, 1133 tile_shape=tile_shape, 1134 halo=halo, 1135 n_threads=n_threads, 1136 verbose=False, 1137 ) 1138 1139 if output_mode is not None: 1140 segmentation = self._to_masks(segmentation, output_mode) 1141 return segmentation 1142 1143 def get_state(self) -> Dict[str, Any]: 1144 """Get the initialized state of the instance segmenter. 1145 1146 Returns: 1147 Instance segmentation state. 1148 """ 1149 if not self.is_initialized: 1150 raise RuntimeError("The state has not been computed yet. Call initialize first.") 1151 1152 return { 1153 "foreground": self._foreground, 1154 "center_distances": self._center_distances, 1155 "boundary_distances": self._boundary_distances, 1156 } 1157 1158 def set_state(self, state: Dict[str, Any]) -> None: 1159 """Set the state of the instance segmenter. 1160 1161 Args: 1162 state: The instance segmentation state 1163 """ 1164 self._foreground = state["foreground"] 1165 self._center_distances = state["center_distances"] 1166 self._boundary_distances = state["boundary_distances"] 1167 self._is_initialized = True 1168 1169 def clear_state(self): 1170 """Clear the state of the instance segmenter. 1171 """ 1172 self._foreground = None 1173 self._center_distances = None 1174 self._boundary_distances = None 1175 self._is_initialized = False
Generates an instance segmentation without prompts, using a decoder.
Implements the same interface as `AutomaticMaskGenerator`.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
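As with the mask generators, `initialize` performs the expensive computation and `generate` only runs the watershed post-processing, so different thresholds can be explored interactively. Passing `output_mode=None` returns the label image directly instead of mask records. A minimal sketch, assuming `predictor`, `decoder` and a 2D `image` are available:

```python
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # Predict embeddings and decoder outputs once.

segmentation = segmenter.generate(
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.6,
    min_size=50,
    output_mode=None,  # Return the label image directly instead of mask records.
)
```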
958 def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None: 959 self._predictor = predictor 960 self._decoder = decoder 961 962 # The decoder outputs. 963 self._foreground = None 964 self._center_distances = None 965 self._boundary_distances = None 966 967 self._is_initialized = False
969 @property 970 def is_initialized(self): 971 """Whether the mask generator has already been initialized. 972 """ 973 return self._is_initialized
Whether the mask generator has already been initialized.
975 @torch.no_grad() 976 def initialize( 977 self, 978 image: np.ndarray, 979 image_embeddings: Optional[util.ImageEmbeddings] = None, 980 i: Optional[int] = None, 981 verbose: bool = False, 982 pbar_init: Optional[callable] = None, 983 pbar_update: Optional[callable] = None, 984 ) -> None: 985 """Initialize image embeddings and decoder predictions for an image. 986 987 Args: 988 image: The input image, volume or timeseries. 989 image_embeddings: Optional precomputed image embeddings. 990 See `util.precompute_image_embeddings` for details. 991 i: Index for the image data. Required if `image` has three spatial dimensions 992 or a time dimension and two spatial dimensions. 993 verbose: Whether to be verbose. By default, set to 'False'. 994 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 995 Can be used together with pbar_update to handle napari progress bar in other thread. 996 To enables using this function within a threadworker. 997 pbar_update: Callback to update an external progress bar. 998 """ 999 _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update) 1000 pbar_init(1, "Initialize instance segmentation with decoder") 1001 1002 if image_embeddings is None: 1003 image_embeddings = util.precompute_image_embeddings(self._predictor, image) 1004 1005 # Get the image embeddings from the predictor. 1006 self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i) 1007 embeddings = self._predictor.features 1008 input_shape = tuple(self._predictor.input_size) 1009 original_shape = tuple(self._predictor.original_size) 1010 1011 # Run prediction with the UNETR decoder. 1012 output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0) 1013 assert output.shape[0] == 3, f"{output.shape}" 1014 pbar_update(1) 1015 pbar_close() 1016 1017 # Set the state. 1018 self._foreground = output[0] 1019 self._center_distances = output[1] 1020 self._boundary_distances = output[2] 1021 self._is_initialized = True
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
1063 def generate( 1064 self, 1065 center_distance_threshold: float = 0.5, 1066 boundary_distance_threshold: float = 0.5, 1067 foreground_threshold: float = 0.5, 1068 foreground_smoothing: float = 1.0, 1069 distance_smoothing: float = 1.6, 1070 min_size: int = 0, 1071 output_mode: Optional[str] = "binary_mask", 1072 tile_shape: Optional[Tuple[int, int]] = None, 1073 halo: Optional[Tuple[int, int]] = None, 1074 n_threads: Optional[int] = None, 1075 ) -> List[Dict[str, Any]]: 1076 """Generate instance segmentation for the currently initialized image. 1077 1078 Args: 1079 center_distance_threshold: Center distance predictions below this value will be 1080 used to find seeds (intersected with thresholded boundary distance predictions). 1081 By default, set to '0.5'. 1082 boundary_distance_threshold: Boundary distance predictions below this value will be 1083 used to find seeds (intersected with thresholded center distance predictions). 1084 By default, set to '0.5'. 1085 foreground_threshold: Foreground predictions above this value will be used as foreground mask. 1086 By default, set to '0.5'. 1087 foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid 1088 checkerboard artifacts in the prediction. By default, set to '1.0'. 1089 distance_smoothing: Sigma value for smoothing the distance predictions. 1090 min_size: Minimal object size in the segmentation result. By default, set to '0'. 1091 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. 1092 By default, set to 'binary_mask'. 1093 tile_shape: Tile shape for parallelizing the instance segmentation post-processing. 1094 This parameter is independent from the tile shape for computing the embeddings. 1095 If not given then post-processing will not be parallelized. 1096 halo: Halo for parallel post-processing. See also `tile_shape`. 1097 n_threads: Number of threads for parallel post-processing. See also `tile_shape`. 1098 1099 Returns: 1100 The instance segmentation masks. 1101 """ 1102 if not self.is_initialized: 1103 raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. 
Call initialize first.") 1104 1105 if foreground_smoothing > 0: 1106 foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) 1107 else: 1108 foreground = self._foreground 1109 1110 if tile_shape is None: 1111 segmentation = watershed_from_center_and_boundary_distances( 1112 center_distances=self._center_distances, 1113 boundary_distances=self._boundary_distances, 1114 foreground_map=foreground, 1115 center_distance_threshold=center_distance_threshold, 1116 boundary_distance_threshold=boundary_distance_threshold, 1117 foreground_threshold=foreground_threshold, 1118 distance_smoothing=distance_smoothing, 1119 min_size=min_size, 1120 ) 1121 else: 1122 if halo is None: 1123 raise ValueError("You must pass a value for halo if tile_shape is given.") 1124 segmentation = _watershed_from_center_and_boundary_distances_parallel( 1125 center_distances=self._center_distances, 1126 boundary_distances=self._boundary_distances, 1127 foreground_map=foreground, 1128 center_distance_threshold=center_distance_threshold, 1129 boundary_distance_threshold=boundary_distance_threshold, 1130 foreground_threshold=foreground_threshold, 1131 distance_smoothing=distance_smoothing, 1132 min_size=min_size, 1133 tile_shape=tile_shape, 1134 halo=halo, 1135 n_threads=n_threads, 1136 verbose=False, 1137 ) 1138 1139 if output_mode is not None: 1140 segmentation = self._to_masks(segmentation, output_mode) 1141 return segmentation
Generate instance segmentation for the currently initialized image.
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions). By default, set to '0.5'.
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions). By default, set to '0.5'.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask. By default, set to '0.5'.
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction. By default, set to '1.0'.
- distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
- min_size: Minimal object size in the segmentation result. By default, set to '0'.
- output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. By default, set to 'binary_mask'.
- tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent of the tile shape for computing the embeddings. If not given, post-processing will not be parallelized.
- halo: Halo for parallel post-processing. See also `tile_shape`.
- n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
Returns:
The instance segmentation masks.
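The typical workflow is to call `initialize` once per image and then `generate`, possibly several times with different thresholds. The following is a minimal sketch; how the predictor and decoder are obtained (here via `get_predictor_and_decoder` with a hypothetical model type) is an assumption and may differ between versions.

import imageio.v3 as imageio
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder, get_predictor_and_decoder

# Load an image and a model with a matching segmentation decoder.
# The helper name and model type below are assumptions; adapt them to your setup.
image = imageio.imread("example_image.tif")
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path=None)

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)

# Pass output_mode=None to obtain the instance segmentation directly as a label image.
segmentation = segmenter.generate(
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.5,
    min_size=50,
    output_mode=None,
)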
    def get_state(self) -> Dict[str, Any]:
        """Get the initialized state of the instance segmenter.

        Returns:
            Instance segmentation state.
        """
        if not self.is_initialized:
            raise RuntimeError("The state has not been computed yet. Call initialize first.")

        return {
            "foreground": self._foreground,
            "center_distances": self._center_distances,
            "boundary_distances": self._boundary_distances,
        }
Get the initialized state of the instance segmenter.
Returns:
Instance segmentation state.
    def set_state(self, state: Dict[str, Any]) -> None:
        """Set the state of the instance segmenter.

        Args:
            state: The instance segmentation state.
        """
        self._foreground = state["foreground"]
        self._center_distances = state["center_distances"]
        self._boundary_distances = state["boundary_distances"]
        self._is_initialized = True
Set the state of the instance segmenter.
Arguments:
- state: The instance segmentation state.
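Because `initialize` is the expensive step, the computed prediction maps can be cached with `get_state` and restored with `set_state`, e.g. to re-segment the same image with different parameters in a later session. A small sketch, reusing the `segmenter` and `image` from the example above and plain `pickle` for storage (any serialization that round-trips numpy arrays would do):

import pickle

# Run the model once and cache the decoder predictions.
segmenter.initialize(image)
with open("instance_state.pkl", "wb") as f:
    pickle.dump(segmenter.get_state(), f)

# Later: restore the cached state and segment again without re-running the model.
with open("instance_state.pkl", "rb") as f:
    segmenter.set_state(pickle.load(f))
segmentation = segmenter.generate(min_size=100, output_mode=None)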
class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
    """

    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
        batched_embeddings = torch.cat(batched_embeddings)
        output = self._decoder._forward_impl(batched_embeddings)

        batched_output = []
        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
            batched_output.append(x.cpu().numpy())
        return batched_output

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
    ) -> None:
        """Initialize image embeddings and decoder predictions for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: Shape of the tiles for precomputing image embeddings.
            halo: Overlap of the tiles for tiled precomputation of image embeddings.
            verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle a napari progress bar in another thread.
                This enables using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
        """
        original_size = image.shape[:2]
        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
        )
        tiling = blocking([0, 0], original_size, tile_shape)

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")

        foreground = np.zeros(original_size, dtype="float32")
        center_distances = np.zeros(original_size, dtype="float32")
        boundary_distances = np.zeros(original_size, dtype="float32")

        n_tiles = tiling.numberOfBlocks
        n_batches = int(np.ceil(n_tiles / batch_size))

        for batch_id in range(n_batches):
            tile_start = batch_id * batch_size
            tile_stop = min(tile_start + batch_size, n_tiles)

            batched_embeddings, input_shapes, original_shapes = [], [], []
            for tile_id in range(tile_start, tile_stop):
                # Get the image embeddings from the predictor for this tile.
                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)

                batched_embeddings.append(self._predictor.features)
                input_shapes.append(tuple(self._predictor.input_size))
                original_shapes.append(tuple(self._predictor.original_size))

            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)

            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
                output = batched_output[output_id]
                assert output.shape[0] == 3

                # Set the predictions in the output for this tile.
                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
                local_bb = tuple(
                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
                )
                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))

                foreground[inner_bb] = output[0][local_bb]
                center_distances[inner_bb] = output[1][local_bb]
                boundary_distances[inner_bb] = output[2][local_bb]
                pbar_update(1)

        pbar_close()

        # Set the state.
        self._foreground = foreground
        self._center_distances = center_distances
        self._boundary_distances = boundary_distances
        self._is_initialized = True
Same as `InstanceSegmentationWithDecoder`, but for tiled image embeddings.
Initialize image embeddings and decoder predictions for an image.
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_shape: Shape of the tiles for precomputing image embeddings.
- halo: Overlap of the tiles for tiled precomputation of image embeddings.
- verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
- batch_size: The batch size for image embedding computation and segmentation decoder prediction.
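For images that are too large for a single embedding computation, the tiled variant computes embeddings and decoder predictions per tile and stitches the results. A minimal sketch, again assuming the `predictor` and `decoder` from the earlier example and a placeholder `large_image`:

from micro_sam.instance_segmentation import TiledInstanceSegmentationWithDecoder

tiled_segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
tiled_segmenter.initialize(
    large_image,              # placeholder for a large 2D image, e.g. of shape (8192, 8192)
    tile_shape=(1024, 1024),  # tile shape for computing the image embeddings
    halo=(256, 256),          # overlap between tiles to reduce stitching artifacts
    batch_size=4,             # number of tiles passed through the decoder per forward pass
)
segmentation = tiled_segmenter.generate(output_mode=None)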
def get_amg(
    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
    """Get the automatic mask generator class.

    Args:
        predictor: The segment anything predictor.
        is_tiled: Whether tiled embeddings are used.
        decoder: Decoder to predict instance segmentation.
        kwargs: The keyword arguments for the amg class.

    Returns:
        The automatic mask generator.
    """
    if decoder is None:
        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
        segmenter = segmenter_class(predictor, **kwargs)
    else:
        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
        segmenter = segmenter_class(predictor, decoder, **kwargs)

    return segmenter
Get the automatic mask generator class.
Arguments:
- predictor: The segment anything predictor.
- is_tiled: Whether tiled embeddings are used.
- decoder: Decoder to predict instance segmentation.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator.
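In practice `get_amg` is a small dispatch helper: without a decoder it returns an automatic mask generator (the tiled variant if `is_tiled=True`), and with a decoder it returns the corresponding `InstanceSegmentationWithDecoder` variant. A hedged sketch, assuming `util.get_sam_model` is used to obtain the Segment Anything predictor and reusing `image` and `decoder` from the earlier examples:

from micro_sam import util
from micro_sam.instance_segmentation import get_amg

predictor = util.get_sam_model(model_type="vit_b")

# No decoder: returns an AutomaticMaskGenerator (TiledAutomaticMaskGenerator if is_tiled=True).
amg = get_amg(predictor, is_tiled=False)
amg.initialize(image)
masks = amg.generate()

# With a decoder: returns an (optionally tiled) InstanceSegmentationWithDecoder instead.
segmenter = get_amg(predictor, is_tiled=False, decoder=decoder)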