micro_sam.visualization
Functionality for visualizing image embeddings.
"""Functionality for visualizing image embeddings.
"""

from typing import Tuple

import numpy as np
from skimage.transform import resize

from nifty.tools import blocking

from elf.segmentation.embeddings import embedding_pca

from .util import ImageEmbeddings


#
# PCA visualization for the image embeddings
#

def compute_pca(embeddings: np.ndarray) -> np.ndarray:
    """Compute the pca projection of the embeddings to visualize them as RGB image.

    Args:
        embeddings: The embeddings. For example predicted by the SAM image encoder.
            Must have 4 dimensions (a single embedding) or 5 dimensions
            (a stack of embeddings that is processed entry by entry along axis 0).

    Returns:
        PCA of the embeddings, mapped to the pixels.

    Raises:
        ValueError: If the input has neither 4 nor 5 dimensions.
    """
    if embeddings.ndim == 4:
        # Singleton axes are dropped before the PCA; the transpose moves the
        # leading axis of the PCA output to the last position.
        pca = embedding_pca(embeddings.squeeze()).transpose((1, 2, 0))
    elif embeddings.ndim == 5:
        pca = []
        for embed in embeddings:
            vis = embedding_pca(embed.squeeze()).transpose((1, 2, 0))
            pca.append(vis)
        pca = np.stack(pca)
    else:
        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")

    return pca


def _get_crop(embed_shape, shape):
    """Compute the crop that matches the embedding to the aspect ratio of `shape`.

    The longer axis of `shape` is kept in full, the shorter axis is cropped
    to `aspect_ratio * embed_shape[axis]`.
    NOTE(review): presumably this removes the embedding region that belongs to
    padding of non-square inputs - confirm against the embedding computation.

    Args:
        embed_shape: The two spatial extents of the embeddings.
        shape: The two spatial extents of the corresponding image or tile.

    Returns:
        A tuple of two slices, one per spatial axis.
    """
    if shape[0] > shape[1]:
        aspect_ratio = float(shape[1] / shape[0])
        crop = np.s_[:, :int(aspect_ratio * embed_shape[1])]
    elif shape[1] > shape[0]:
        aspect_ratio = float(shape[0] / shape[1])
        crop = np.s_[:int(aspect_ratio * embed_shape[0]), :]
    else:
        # Square image, we don't need to crop anything. Two slices are returned
        # so that the result has the same length as in the other branches.
        crop = np.s_[:, :]

    return crop


def _project_embeddings(embeddings, shape, apply_crop=True):
    """Project embeddings to their PCA and compute the scale to the data shape.

    Args:
        embeddings: The embeddings, with two more dimensions than `shape` has entries.
        shape: The shape of the data the embeddings belong to (2d or 3d).
        apply_crop: Whether to crop the projection to the aspect ratio of `shape`.

    Returns:
        The PCA projection of the embeddings.
        The per-axis ratio between `shape` and the shape of the projection.

    Raises:
        ValueError: If cropping is requested and `shape` is neither 2d nor 3d.
    """
    assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"

    embedding_vis = compute_pca(embeddings)
    if apply_crop:
        if len(shape) == 2:
            crop = _get_crop(embedding_vis.shape, shape)
            embedding_vis = embedding_vis[crop]
        elif len(shape) == 3:
            # The crop only applies to the two trailing data axes,
            # the leading axis of the stack is kept in full.
            crop = _get_crop(embedding_vis.shape[1:], shape[1:])
            crop = (slice(None),) + crop
            embedding_vis = embedding_vis[crop]
        else:
            raise ValueError(f"Expect 2d or 3d data, got {len(shape)}")

    # zip stops after len(shape) entries, so the trailing axis of the
    # projection does not contribute to the scale.
    scale = tuple(float(sh) / vsh for sh, vsh in zip(shape, embedding_vis.shape))
    return embedding_vis, scale


def _project_embeddings_to_tile(tile, tile_embeds):
    """Extract the part of the tile embeddings that belongs to the inner tile.

    Args:
        tile: The tile description; its `outerBlock` and `innerBlockLocal`
            attributes (each with `begin` and `end`) are read.
        tile_embeds: The embeddings computed for the outer tile.

    Returns:
        The embeddings restricted to the inner tile.
    """
    outer_tile = tile.outerBlock
    inner_tile_local = tile.innerBlockLocal

    embed_shape = tile_embeds.shape[-2:]
    outer_tile_shape = tuple(end - beg for beg, end in zip(outer_tile.begin, outer_tile.end))

    # Crop the two trailing axes to the aspect ratio of the outer tile
    # and keep all leading axes in full.
    crop = _get_crop(embed_shape, outer_tile_shape)
    crop = (tile_embeds.ndim - len(crop)) * (slice(None),) + crop
    this_embeds = tile_embeds[crop]

    # Map the local coordinates of the inner tile into the embedding space.
    tile_scale = tuple(esh / float(fsh) for esh, fsh in zip(this_embeds.shape[-2:], outer_tile_shape))
    tile_bb = tuple(
        slice(int(np.round(beg * scale)), int(np.round(end * scale)))
        for beg, end, scale in zip(inner_tile_local.begin, inner_tile_local.end, tile_scale)
    )
    tile_bb = (tile_embeds.ndim - len(outer_tile_shape)) * (slice(None),) + tile_bb

    this_embeds = this_embeds[tile_bb]
    return this_embeds


def _resize_and_cocatenate(arrays, axis):
    """Concatenate arrays along one of the two last axes.

    Before concatenation all arrays are resized so that the other one of the
    two last axes has the same length, given by the maximum over all arrays.

    Args:
        arrays: The arrays to concatenate.
        axis: The concatenation axis, either -1 or -2.

    Returns:
        The concatenated array.
    """
    assert axis in (-1, -2)
    resize_axis = -1 if axis == -2 else -2
    resize_len = max([arr.shape[resize_axis] for arr in arrays])

    def resize_shape(shape):
        # Translate the negative axis into a non-negative index.
        axis_ = arrays[0].ndim + resize_axis
        return tuple(resize_len if i == axis_ else sh for i, sh in enumerate(shape))

    return np.concatenate([resize(arr, resize_shape(arr.shape)) for arr in arrays], axis=axis)


def _project_tiled_embeddings(image_embeddings):
    """Stitch tiled embeddings into a single array and project it to its PCA.

    Args:
        image_embeddings: The tiled image embeddings. The "features" entry is
            indexed by tile id and provides the attributes "tile_shape",
            "halo" and "shape".

    Returns:
        The PCA projection of the stitched embeddings.
        The per-axis ratio between the data shape and the shape of the projection.
    """
    features = image_embeddings["features"]
    tile_shape, halo, shape = features.attrs["tile_shape"], features.attrs["halo"], features.attrs["shape"]
    tiling = blocking([0, 0], shape, tile_shape)

    tile_grid = tiling.blocksPerAxis

    # The embeddings of the inner tiles, indexed by their grid position.
    embeds = {
        i: {j: None for j in range(tile_grid[1])} for i in range(tile_grid[0])
    }

    for tile_id in range(tiling.numberOfBlocks):
        tile_embeds = features[tile_id][:]
        assert tile_embeds.ndim in (4, 5)

        # extract the embeddings corresponding to the inner tile
        tile = tiling.getBlockWithHalo(tile_id, list(halo))
        tile_coords = tiling.blockGridPosition(tile_id)
        this_embeds = _project_embeddings_to_tile(tile, tile_embeds)

        i, j = tile_coords
        embeds[i][j] = this_embeds

    # First merge the tiles of each grid row along the last axis,
    # then merge the rows along the second to last axis.
    embeds = _resize_and_cocatenate(
        [
            _resize_and_cocatenate(
                [embeds[i][j] for j in range(tile_grid[1])], axis=-1
            )
            for i in range(tile_grid[0])
        ], axis=-2
    )

    if features[0].ndim == 5:
        shape = (features[0].shape[0],) + tuple(shape)
    # The tiles were already cropped individually, so no further crop is applied.
    embedding_vis, scale = _project_embeddings(embeds, shape, apply_crop=False)
    return embedding_vis, scale


def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    # Embeddings without an input size are treated as tiled embeddings.
    is_tiled = image_embeddings["input_size"] is None
    if is_tiled:
        embedding_vis, scale = _project_tiled_embeddings(image_embeddings)
    else:
        embeddings = image_embeddings["features"]
        shape = tuple(image_embeddings["original_size"])
        if embeddings.ndim == 5:
            shape = (embeddings.shape[0],) + shape
        embedding_vis, scale = _project_embeddings(embeddings, shape)

    return embedding_vis, scale
def
compute_pca(embeddings: numpy.ndarray) -> numpy.ndarray:
def compute_pca(embeddings: np.ndarray) -> np.ndarray:
    """Project embeddings to their PCA so that they can be shown as RGB image.

    Args:
        embeddings: The embeddings. For example predicted by the SAM image encoder.
            Either a single embedding with 4 dimensions or a stack of
            embeddings with 5 dimensions, which is processed entry by entry
            along its first axis.

    Returns:
        PCA of the embeddings, mapped to the pixels.

    Raises:
        ValueError: If the input has neither 4 nor 5 dimensions.
    """
    n_dims = embeddings.ndim
    if n_dims not in (4, 5):
        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")

    def _to_rgb(single_embedding):
        # Drop singleton axes, run the PCA and move its leading axis to the end.
        projected = embedding_pca(single_embedding.squeeze())
        return projected.transpose((1, 2, 0))

    if n_dims == 4:
        return _to_rgb(embeddings)

    # Stack of embeddings: project each entry on its own and re-stack.
    return np.stack([_to_rgb(entry) for entry in embeddings])
Compute the pca projection of the embeddings to visualize them as RGB image.
Arguments:
- embeddings: The embeddings. For example predicted by the SAM image encoder.
Returns:
PCA of the embeddings, mapped to the pixels.
def
project_embeddings_for_visualization( image_embeddings: Dict[str, Any]) -> Tuple[numpy.ndarray, Tuple[float, ...]]:
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Compute the pixel-wise PCA projection of image embeddings.

    Embeddings whose "input_size" entry is None are handled as tiled
    embeddings; otherwise the "features" are projected directly, using the
    "original_size" entry as the shape of the data.

    Args:
        image_embeddings: The image embeddings.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    if image_embeddings["input_size"] is None:
        projection, scale_factors = _project_tiled_embeddings(image_embeddings)
        return projection, scale_factors

    features = image_embeddings["features"]
    data_shape = tuple(image_embeddings["original_size"])
    # For a stack of embeddings the length of the stack becomes the leading
    # entry of the data shape.
    if features.ndim == 5:
        data_shape = (features.shape[0],) + data_shape

    projection, scale_factors = _project_embeddings(features, data_shape)
    return projection, scale_factors
Project image embeddings to pixel-wise PCA.
Arguments:
- image_embeddings: The image embeddings.
Returns:
The PCA of the embeddings. The scale factor for resizing to the original image size.