micro_sam.visualization
Functionality for visualizing image embeddings.
"""Functionality for visualizing image embeddings.
"""

from typing import Tuple

import numpy as np
from skimage.transform import resize

from nifty.tools import blocking

from elf.segmentation.embeddings import embedding_pca

from .util import ImageEmbeddings


#
# PCA visualization for the image embeddings
#

def compute_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
    """Compute the PCA projection of the embeddings to visualize them as an RGB image.

    Args:
        embeddings: The embeddings. For example predicted by the SAM image encoder.
            Expected to be 4d with a leading singleton (batch) axis, or 5d, i.e. a stack of such 4d embeddings.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.

    Returns:
        PCA of the embeddings, mapped to the pixels, with the component axis last.
        For 5d input the per-slice results are stacked along a new leading axis.

    Raises:
        ValueError: If the embeddings are neither 4d nor 5d.
    """
    def _pca_2d(embed):
        # Remove only the leading singleton axis. A bare `squeeze()` would also drop any
        # other axis of size 1 and make the transpose below fail.
        vis = embedding_pca(embed.squeeze(0), n_components=n_components, as_rgb=as_rgb)
        # Move the first axis (presumably the PCA components) to the end.
        return vis.transpose((1, 2, 0))

    if embeddings.ndim == 4:
        return _pca_2d(embeddings)
    if embeddings.ndim == 5:
        # Project each slice independently, then stack the results along a new first axis.
        return np.stack([_pca_2d(embed) for embed in embeddings])
    raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")


def _get_crop(embed_shape, shape):
    """Get the slice selecting the part of the embeddings that corresponds to the image.

    Args:
        embed_shape: The shape of the embeddings; only its first two entries (rows, columns) are used.
        shape: The 2d shape of the image (or tile) the embeddings belong to.

    Returns:
        A tuple of two slices. For a square image everything is selected. Otherwise the embeddings
        along the image's shorter side are cut to the image's aspect ratio, keeping the leading part.
        NOTE(review): this presumes the remainder is padding added before encoding - confirm.
    """
    if shape[0] == shape[1]:  # square image, we don't need to crop anything
        return np.s_[:, :]
    if shape[0] > shape[1]:  # tall image: crop the columns
        aspect_ratio = float(shape[1] / shape[0])
        return np.s_[:, :int(aspect_ratio * embed_shape[1])]
    # wide image: crop the rows
    aspect_ratio = float(shape[0] / shape[1])
    return np.s_[:int(aspect_ratio * embed_shape[0]), :]


def _project_embeddings(embeddings, shape, apply_crop=True, n_components=3, as_rgb=True):
    """Compute the PCA visualization of the embeddings and its scale relative to `shape`.

    Args:
        embeddings: The embeddings; must have exactly two more dimensions than `shape`.
        shape: The target shape, either 2d or 3d (with a leading slice axis).
        apply_crop: Whether to crop the visualization to the region corresponding to the image.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projection so that it can be displayed as RGB.

    Returns:
        The PCA visualization.
        The scale factors `shape[k] / visualization.shape[k]` for each axis of `shape`.
    """
    assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"

    embedding_vis = compute_pca(embeddings, n_components=n_components, as_rgb=as_rgb)
    if apply_crop:
        if len(shape) == 2:
            embedding_vis = embedding_vis[_get_crop(embedding_vis.shape, shape)]
        elif len(shape) == 3:
            # The first axis holds the slices and is never cropped.
            crop = (slice(None),) + _get_crop(embedding_vis.shape[1:], shape[1:])
            embedding_vis = embedding_vis[crop]
        else:
            raise ValueError(f"Expect 2d or 3d data, got {len(shape)}")

    # zip stops after len(shape) entries, so the trailing component axis gets no scale factor.
    scale = tuple(float(sh) / vsh for sh, vsh in zip(shape, embedding_vis.shape))
    return embedding_vis, scale


def _project_embeddings_to_tile(tile, tile_embeds):
    """Extract the embeddings belonging to the inner block (i.e. without halo) of a tile.

    Args:
        tile: The tile with halo, as returned by `blocking.getBlockWithHalo`.
        tile_embeds: The embeddings of the tile; the last two axes are the spatial axes.
            Presumably computed for the outer block (including halo) - confirm.

    Returns:
        The embeddings restricted to the inner block of the tile.
    """
    outer_tile = tile.outerBlock
    inner_tile_local = tile.innerBlockLocal

    embed_shape = tile_embeds.shape[-2:]
    outer_tile_shape = tuple(end - beg for beg, end in zip(outer_tile.begin, outer_tile.end))

    # Crop the spatial axes to the region corresponding to the outer tile; leading axes are kept.
    crop = _get_crop(embed_shape, outer_tile_shape)
    crop = (tile_embeds.ndim - len(crop)) * (slice(None),) + crop
    this_embeds = tile_embeds[crop]

    # Map the inner block from pixel coordinates (local to the outer tile) to embedding coordinates.
    tile_scale = tuple(esh / float(fsh) for esh, fsh in zip(this_embeds.shape[-2:], outer_tile_shape))
    tile_bb = tuple(
        slice(int(np.round(beg * scale)), int(np.round(end * scale)))
        for beg, end, scale in zip(inner_tile_local.begin, inner_tile_local.end, tile_scale)
    )
    tile_bb = (tile_embeds.ndim - len(outer_tile_shape)) * (slice(None),) + tile_bb

    return this_embeds[tile_bb]


def _resize_and_concatenate(arrays, axis):
    """Concatenate arrays along `axis` after resizing them to a common size along the other spatial axis.

    Args:
        arrays: The arrays to concatenate; they must all have the same number of dimensions.
        axis: The concatenation axis, either -1 or -2. The respective other of these two axes
            is resized to the largest extent found among the arrays.

    Returns:
        The concatenated array.
    """
    assert axis in (-1, -2)
    resize_axis = -1 if axis == -2 else -2
    resize_len = max(arr.shape[resize_axis] for arr in arrays)
    resize_axis_pos = arrays[0].ndim + resize_axis

    def resize_shape(shape):
        return tuple(resize_len if i == resize_axis_pos else sh for i, sh in enumerate(shape))

    resized = []
    for arr in arrays:
        target_shape = resize_shape(arr.shape)
        # Only interpolate arrays that actually need it.
        resized.append(arr if arr.shape == target_shape else resize(arr, target_shape))
    return np.concatenate(resized, axis=axis)


# Backwards-compatible alias for the previous (misspelled) helper name.
_resize_and_cocatenate = _resize_and_concatenate


def _project_tiled_embeddings(image_embeddings, n_components, as_rgb):
    """Stitch tiled embeddings together and compute their PCA visualization.

    The per-tile embeddings are read from `image_embeddings["features"][str(tile_id)]`. The tiling is
    taken from the `tile_shape`, `halo` and `shape` attributes of the features. The inner block of
    each tile is cut out, and the tiles are concatenated row by row, resizing them to a common size
    where necessary. The PCA is then computed for the stitched embeddings without further cropping.

    Returns:
        The PCA visualization and the scale factors, see `_project_embeddings`.
    """
    features = image_embeddings["features"]
    tile_shape, halo, shape = features.attrs["tile_shape"], features.attrs["halo"], features.attrs["shape"]
    tiling = blocking([0, 0], shape, tile_shape)

    tile_grid = tiling.blocksPerAxis
    # embeds[i][j] holds the inner-block embeddings of the tile at grid position (i, j).
    embeds = [[None] * tile_grid[1] for _ in range(tile_grid[0])]

    for tile_id in range(tiling.numberOfBlocks):
        tile_embeds = features[str(tile_id)][:]
        assert tile_embeds.ndim in (4, 5)

        # extract the embeddings corresponding to the inner tile
        tile = tiling.getBlockWithHalo(tile_id, list(halo))
        i, j = tiling.blockGridPosition(tile_id)
        embeds[i][j] = _project_embeddings_to_tile(tile, tile_embeds)

    # Join the tiles of each row along the last axis, then stack the rows along the second-to-last axis.
    rows = [_resize_and_concatenate(row, axis=-1) for row in embeds]
    embeds = _resize_and_concatenate(rows, axis=-2)

    if features["0"].ndim == 5:
        shape = (features["0"].shape[0],) + tuple(shape)
    embedding_vis, scale = _project_embeddings(
        embeds, shape, n_components=n_components, as_rgb=as_rgb, apply_crop=False
    )
    return embedding_vis, scale


def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings, n_components: int = 3, as_rgb: bool = True,
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    # An `input_size` of None marks tiled embeddings.
    is_tiled = image_embeddings["input_size"] is None
    if is_tiled:
        embedding_vis, scale = _project_tiled_embeddings(image_embeddings, n_components, as_rgb)
    else:
        embeddings = image_embeddings["features"]
        shape = tuple(image_embeddings["original_size"])
        if embeddings.ndim == 5:
            # 5d features: prepend the size of the leading (slice) axis to the target shape.
            shape = (embeddings.shape[0],) + shape
        embedding_vis, scale = _project_embeddings(embeddings, shape, n_components=n_components, as_rgb=as_rgb)

    return embedding_vis, scale
def
compute_pca( embeddings: numpy.ndarray, n_components: int = 3, as_rgb: bool = True) -> numpy.ndarray:
def compute_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
    """Compute the PCA projection of the embeddings to visualize them as an RGB image.

    Args:
        embeddings: The embeddings. For example predicted by the SAM image encoder.
            Expected to be 4d with a leading singleton (batch) axis, or 5d, i.e. a stack of such 4d embeddings.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.

    Returns:
        PCA of the embeddings, mapped to the pixels, with the component axis last.
        For 5d input the per-slice results are stacked along a new leading axis.

    Raises:
        ValueError: If the embeddings are neither 4d nor 5d.
    """
    def _pca_2d(embed):
        # Remove only the leading singleton axis. A bare `squeeze()` would also drop any
        # other axis of size 1 and make the transpose below fail.
        vis = embedding_pca(embed.squeeze(0), n_components=n_components, as_rgb=as_rgb)
        # Move the first axis (presumably the PCA components) to the end.
        return vis.transpose((1, 2, 0))

    if embeddings.ndim == 4:
        return _pca_2d(embeddings)
    if embeddings.ndim == 5:
        # Project each slice independently, then stack the results along a new first axis.
        return np.stack([_pca_2d(embed) for embed in embeddings])
    raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")
Compute the PCA projection of the embeddings to visualize them as an RGB image.
Arguments:
- embeddings: The embeddings. For example predicted by the SAM image encoder.
- n_components: The number of PCA components to use for dimensionality reduction.
- as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.
Returns:
PCA of the embeddings, mapped to the pixels.
def
project_embeddings_for_visualization( image_embeddings: Dict[str, Any], n_components: int = 3, as_rgb: bool = True) -> Tuple[numpy.ndarray, Tuple[float, ...]]:
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings, n_components: int = 3, as_rgb: bool = True,
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to a pixel-wise PCA for visualization.

    Args:
        image_embeddings: The image embeddings.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    # An `input_size` of None marks tiled embeddings, which take a separate code path.
    if image_embeddings["input_size"] is None:
        vis, scale_factors = _project_tiled_embeddings(image_embeddings, n_components, as_rgb)
        return vis, scale_factors

    features = image_embeddings["features"]
    target_shape = tuple(image_embeddings["original_size"])
    # 5d features: prepend the size of the leading (slice) axis to the target shape.
    target_shape = (features.shape[0],) + target_shape if features.ndim == 5 else target_shape

    vis, scale_factors = _project_embeddings(
        features, target_shape, n_components=n_components, as_rgb=as_rgb
    )
    return vis, scale_factors
Project image embeddings to pixel-wise PCA.
Arguments:
- image_embeddings: The image embeddings.
- n_components: The number of PCA components to use for dimensionality reduction.
- as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.
Returns:
The PCA of the embeddings, and the scale factor for resizing it to the original image size.