micro_sam.visualization
Functionality for visualizing image embeddings.
"""
Functionality for visualizing image embeddings.
"""
from typing import Tuple

import numpy as np
from skimage.transform import resize

from nifty.tools import blocking

from elf.segmentation.embeddings import embedding_pca

from .util import ImageEmbeddings


#
# PCA visualization for the image embeddings
#

def compute_pca(embeddings: np.ndarray) -> np.ndarray:
    """Compute the PCA projection of the embeddings to visualize them as an RGB image.

    Args:
        embeddings: The embeddings, for example as predicted by the SAM image encoder.
            Must be 4d (a single image) or 5d (a stack whose entries along the first
            axis are projected separately).

    Returns:
        PCA of the embeddings, mapped to the pixels. The leading axis of each PCA result
        is moved to the end. For 5d input the per-entry results are stacked along a new
        first axis.

    Raises:
        ValueError: If `embeddings` is neither 4d nor 5d.
    """
    def _project(embed):
        # Drop singleton axes (presumably a leading batch axis of size one), then move
        # the leading axis of the PCA result (presumably the components) to the end.
        return embedding_pca(embed.squeeze()).transpose((1, 2, 0))

    if embeddings.ndim == 4:
        return _project(embeddings)
    if embeddings.ndim == 5:
        return np.stack([_project(embed) for embed in embeddings])
    raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")


def _get_crop(embed_shape, shape):
    """Return a slice that crops 2d embeddings to the aspect ratio of `shape`.

    Only the first two entries of `embed_shape` and `shape` are used. For a square
    `shape` everything is kept. Otherwise the embedding axis corresponding to the
    shorter image axis is cut to `aspect_ratio * embed_shape[axis]` (truncated to int).
    """
    if shape[0] == shape[1]:  # square image, we don't need to do anything
        return np.s_[:, :, :]
    if shape[0] > shape[1]:
        aspect_ratio = float(shape[1] / shape[0])
        return np.s_[:, :int(aspect_ratio * embed_shape[1])]
    # shape[1] > shape[0]
    aspect_ratio = float(shape[0] / shape[1])
    return np.s_[:int(aspect_ratio * embed_shape[0]), :]


def _project_embeddings(embeddings, shape, apply_crop=True):
    """Compute the PCA visualization of `embeddings` and its scale relative to `shape`.

    Args:
        embeddings: Embeddings with exactly two more dimensions than `shape` (asserted).
        shape: The data shape: 2d, or 3d with the slice axis first.
        apply_crop: Whether to crop the visualization to the aspect ratio of `shape`.

    Returns:
        The (optionally cropped) visualization and the per-axis scale factors
        `shape[i] / visualization.shape[i]`.

    Raises:
        ValueError: If `apply_crop` is set and `shape` is neither 2d nor 3d.
    """
    assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"

    embedding_vis = compute_pca(embeddings)
    if apply_crop:
        if len(shape) == 2:
            crop = _get_crop(embedding_vis.shape, shape)
        elif len(shape) == 3:
            # Keep all slices and crop each one in-plane.
            crop = (slice(None),) + _get_crop(embedding_vis.shape[1:], shape[1:])
        else:
            raise ValueError(f"Expect 2d or 3d data, got {len(shape)}")
        embedding_vis = embedding_vis[crop]

    scale = tuple(float(sh) / vsh for sh, vsh in zip(shape, embedding_vis.shape))
    return embedding_vis, scale


def _project_embeddings_to_tile(tile, tile_embeds):
    """Cut the embeddings of one tile down to the region of its inner tile.

    Args:
        tile: Block with halo (from `blocking.getBlockWithHalo`); its `outerBlock` and
            `innerBlockLocal` are used.
        tile_embeds: Embeddings of the outer tile; the last two axes are spatial.

    Returns:
        The embeddings restricted to the inner tile, with all leading axes kept.
    """
    outer_tile = tile.outerBlock
    inner_tile_local = tile.innerBlockLocal

    embed_shape = tile_embeds.shape[-2:]
    outer_tile_shape = tuple(end - beg for beg, end in zip(outer_tile.begin, outer_tile.end))

    # Crop the spatial axes to the aspect ratio of the outer tile, keeping leading axes.
    crop = _get_crop(embed_shape, outer_tile_shape)
    crop = (tile_embeds.ndim - len(crop)) * (slice(None),) + crop
    this_embeds = tile_embeds[crop]

    # Map the inner tile (in pixel coordinates local to the outer tile) to embedding coordinates.
    tile_scale = tuple(esh / float(fsh) for esh, fsh in zip(this_embeds.shape[-2:], outer_tile_shape))
    tile_bb = tuple(
        slice(int(np.round(beg * scale)), int(np.round(end * scale)))
        for beg, end, scale in zip(inner_tile_local.begin, inner_tile_local.end, tile_scale)
    )
    tile_bb = (tile_embeds.ndim - len(outer_tile_shape)) * (slice(None),) + tile_bb
    return this_embeds[tile_bb]


def _resize_and_concatenate(arrays, axis):
    """Concatenate `arrays` along `axis` after matching the other spatial axis.

    Args:
        arrays: Non-empty sequence of arrays with the same number of dimensions.
        axis: Concatenation axis, -1 or -2. The other of the last two axes is resized
            to the largest length among `arrays`.

    Returns:
        The concatenated array.
    """
    assert axis in (-1, -2)
    resize_axis = -1 if axis == -2 else -2
    resize_len = max(arr.shape[resize_axis] for arr in arrays)
    axis_ = arrays[0].ndim + resize_axis

    def _match(arr):
        target_shape = tuple(resize_len if i == axis_ else sh for i, sh in enumerate(arr.shape))
        # Arrays that already have the target shape are passed through instead of being
        # re-interpolated, which previously cost a full extra pass over every tile.
        if arr.shape == target_shape:
            return arr
        return resize(arr, target_shape)

    return np.concatenate([_match(arr) for arr in arrays], axis=axis)


# Backward-compatible alias for the original (misspelled) helper name.
_resize_and_cocatenate = _resize_and_concatenate


def _project_tiled_embeddings(image_embeddings):
    """Stitch tiled embeddings and compute their PCA visualization.

    Reads `tile_shape`, `halo` and `shape` from `image_embeddings["features"].attrs`,
    and the embeddings of tile `tile_id` from `features[tile_id]`.

    Returns:
        The visualization and per-axis scale factors, as for `_project_embeddings`.
    """
    features = image_embeddings["features"]
    tile_shape, halo, shape = features.attrs["tile_shape"], features.attrs["halo"], features.attrs["shape"]
    tiling = blocking([0, 0], shape, tile_shape)

    tile_grid = tiling.blocksPerAxis

    # embeds[i][j] holds the inner-tile embeddings of the tile at grid position (i, j).
    embeds = [[None] * tile_grid[1] for _ in range(tile_grid[0])]

    for tile_id in range(tiling.numberOfBlocks):
        tile_embeds = features[tile_id][:]
        assert tile_embeds.ndim in (4, 5)

        # extract the embeddings corresponding to the inner tile
        tile = tiling.getBlockWithHalo(tile_id, list(halo))
        i, j = tiling.blockGridPosition(tile_id)
        embeds[i][j] = _project_embeddings_to_tile(tile, tile_embeds)

    # Join each row of tiles along the width, then stack the rows along the height.
    embeds = _resize_and_concatenate(
        [_resize_and_concatenate(row, axis=-1) for row in embeds], axis=-2
    )

    if features[0].ndim == 5:
        shape = (features[0].shape[0],) + tuple(shape)
    embedding_vis, scale = _project_embeddings(embeds, shape, apply_crop=False)
    return embedding_vis, scale


def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    # An `input_size` of None marks tiled embeddings.
    is_tiled = image_embeddings["input_size"] is None
    if is_tiled:
        embedding_vis, scale = _project_tiled_embeddings(image_embeddings)
    else:
        embeddings = image_embeddings["features"]
        shape = tuple(image_embeddings["original_size"])
        if embeddings.ndim == 5:
            # Stacked (3d) embeddings: prepend the number of slices to the shape.
            shape = (embeddings.shape[0],) + shape
        embedding_vis, scale = _project_embeddings(embeddings, shape)
    return embedding_vis, scale
def compute_pca(embeddings: numpy.ndarray) -> numpy.ndarray:
def compute_pca(embeddings: np.ndarray) -> np.ndarray:
    """Compute the PCA projection of the embeddings to visualize them as an RGB image.

    Args:
        embeddings: The embeddings, for example as predicted by the SAM image encoder.
            Must be 4d (a single image) or 5d (a stack whose entries along the first
            axis are projected separately).

    Returns:
        PCA of the embeddings, mapped to the pixels. The leading axis of each PCA result
        is moved to the end. For 5d input the per-entry results are stacked along a new
        first axis.

    Raises:
        ValueError: If `embeddings` is neither 4d nor 5d.
    """
    def _project(embed):
        # Drop singleton axes (presumably a leading batch axis of size one), then move
        # the leading axis of the PCA result (presumably the components) to the end.
        return embedding_pca(embed.squeeze()).transpose((1, 2, 0))

    if embeddings.ndim == 4:
        return _project(embeddings)
    if embeddings.ndim == 5:
        return np.stack([_project(embed) for embed in embeddings])
    raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")
Compute the PCA projection of the embeddings to visualize them as an RGB image.
Arguments:
- embeddings: The embeddings, for example as predicted by the SAM image encoder.
Returns:
PCA of the embeddings, mapped to the pixels.
def project_embeddings_for_visualization(image_embeddings: Dict[str, Any]) -> Tuple[numpy.ndarray, Tuple[float, ...]]:
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to a pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings. An `input_size` of None marks tiled
            embeddings.

    Returns:
        The PCA of the embeddings.
        The per-axis scale factors for resizing it to the original image size.
    """
    if image_embeddings["input_size"] is None:
        # Tiled embeddings are stitched tile by tile.
        return _project_tiled_embeddings(image_embeddings)

    features = image_embeddings["features"]
    target_shape = tuple(image_embeddings["original_size"])
    if features.ndim == 5:
        # Stacked (3d) embeddings: prepend the number of slices to the shape.
        target_shape = (features.shape[0],) + target_shape
    return _project_embeddings(features, target_shape)
Project image embeddings to pixel-wise PCA.
Arguments:
- image_embeddings: The image embeddings.
Returns:
The PCA of the embeddings, and the per-axis scale factors for resizing it to the original image size.