micro_sam.visualization

Functionality for visualizing image embeddings.

  1"""
  2Functionality for visualizing image embeddings.
  3"""
  4from typing import Tuple
  5
  6import numpy as np
  7from skimage.transform import resize
  8
  9from nifty.tools import blocking
 10
 11from elf.segmentation.embeddings import embedding_pca
 12
 13from .util import ImageEmbeddings
 14
 15
 16#
 17# PCA visualization for the image embeddings
 18#
 19
 20def compute_pca(embeddings: np.ndarray) -> np.ndarray:
 21    """Compute the pca projection of the embeddings to visualize them as RGB image.
 22
 23    Args:
 24        embeddings: The embeddings. For example predicted by the SAM image encoder.
 25
 26    Returns:
 27        PCA of the embeddings, mapped to the pixels.
 28    """
 29    if embeddings.ndim == 4:
 30        pca = embedding_pca(embeddings.squeeze()).transpose((1, 2, 0))
 31    elif embeddings.ndim == 5:
 32        pca = []
 33        for embed in embeddings:
 34            vis = embedding_pca(embed.squeeze()).transpose((1, 2, 0))
 35            pca.append(vis)
 36        pca = np.stack(pca)
 37    else:
 38        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")
 39    return pca
 40
 41
 42def _get_crop(embed_shape, shape):
 43    if shape[0] == shape[1]:  # square image, we don't need to do anything
 44        crop = np.s_[:, :, :]
 45    elif shape[0] > shape[1]:
 46        aspect_ratio = float(shape[1] / shape[0])
 47        crop = np.s_[:, :int(aspect_ratio * embed_shape[1])]
 48    elif shape[1] > shape[0]:
 49        aspect_ratio = float(shape[0] / shape[1])
 50        crop = np.s_[:int(aspect_ratio * embed_shape[0]), :]
 51    return crop
 52
 53
 54def _project_embeddings(embeddings, shape, apply_crop=True):
 55    assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"
 56
 57    embedding_vis = compute_pca(embeddings)
 58    if not apply_crop:
 59        pass
 60    elif len(shape) == 2:
 61        crop = _get_crop(embedding_vis.shape, shape)
 62        embedding_vis = embedding_vis[crop]
 63    elif len(shape) == 3:
 64        crop = _get_crop(embedding_vis.shape[1:], shape[1:])
 65        crop = (slice(None),) + crop
 66        embedding_vis = embedding_vis[crop]
 67    else:
 68        raise ValueError(f"Expect 2d or 3d data, got {len(shape)}")
 69
 70    scale = tuple(float(sh) / vsh for sh, vsh in zip(shape, embedding_vis.shape))
 71    return embedding_vis, scale
 72
 73
 74def _project_embeddings_to_tile(tile, tile_embeds):
 75    outer_tile = tile.outerBlock
 76    inner_tile_local = tile.innerBlockLocal
 77
 78    embed_shape = tile_embeds.shape[-2:]
 79    outer_tile_shape = tuple(end - beg for beg, end in zip(outer_tile.begin, outer_tile.end))
 80
 81    crop = _get_crop(embed_shape, outer_tile_shape)
 82    crop = (tile_embeds.ndim - len(crop)) * (slice(None),) + crop
 83    this_embeds = tile_embeds[crop]
 84
 85    tile_scale = tuple(esh / float(fsh) for esh, fsh in zip(this_embeds.shape[-2:], outer_tile_shape))
 86    tile_bb = tuple(
 87        slice(int(np.round(beg * scale)), int(np.round(end * scale)))
 88        for beg, end, scale in zip(inner_tile_local.begin, inner_tile_local.end, tile_scale)
 89    )
 90    tile_bb = (tile_embeds.ndim - len(outer_tile_shape)) * (slice(None),) + tile_bb
 91
 92    this_embeds = this_embeds[tile_bb]
 93    return this_embeds
 94
 95
 96def _resize_and_cocatenate(arrays, axis):
 97    assert axis in (-1, -2)
 98    resize_axis = -1 if axis == -2 else -2
 99    resize_len = max([arr.shape[resize_axis] for arr in arrays])
100
101    def resize_shape(shape):
102        axis_ = arrays[0].ndim + resize_axis
103        return tuple(resize_len if i == axis_ else sh for i, sh in enumerate(shape))
104
105    return np.concatenate(
106        [resize(arr, resize_shape(arr.shape)) for arr in arrays],
107        axis=axis
108    )
109
110
def _project_tiled_embeddings(image_embeddings):
    """Compute the PCA visualization for tiled image embeddings.

    The embeddings of every tile are restricted to the tile's inner region.
    The tiles are stitched together row by row, and the stitched embeddings
    are projected without further cropping.

    Args:
        image_embeddings: The tiled image embeddings. Its "features" entry is
            indexed by tile id and carries the "tile_shape", "halo" and "shape"
            attributes.

    Returns:
        The PCA visualization and the scale factor to the original image size.
    """
    features = image_embeddings["features"]
    attrs = features.attrs
    tile_shape, halo, shape = attrs["tile_shape"], attrs["halo"], attrs["shape"]
    tiling = blocking([0, 0], shape, tile_shape)

    n_rows, n_cols = tiling.blocksPerAxis[0], tiling.blocksPerAxis[1]
    grid = [[None] * n_cols for _ in range(n_rows)]

    for tile_id in range(tiling.numberOfBlocks):
        tile_embeds = features[tile_id][:]
        assert tile_embeds.ndim in (4, 5)

        # Keep only the embeddings that belong to the inner tile.
        tile = tiling.getBlockWithHalo(tile_id, list(halo))
        row_idx, col_idx = tiling.blockGridPosition(tile_id)
        grid[row_idx][col_idx] = _project_embeddings_to_tile(tile, tile_embeds)

    # Stitch each row along the last axis, then stack the rows along the second-to-last.
    rows = [_resize_and_cocatenate(row_embeds, axis=-1) for row_embeds in grid]
    stitched = _resize_and_cocatenate(rows, axis=-2)

    if features[0].ndim == 5:
        shape = (features[0].shape[0],) + tuple(shape)
    return _project_embeddings(stitched, shape, apply_crop=False)
147
148
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to a pixel-wise PCA for visualization.

    Args:
        image_embeddings: The image embeddings. They are handled as tiled
            embeddings if their "input_size" entry is None.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    if image_embeddings["input_size"] is None:
        # Tiled embeddings are stitched before they are projected.
        return _project_tiled_embeddings(image_embeddings)

    embeddings = image_embeddings["features"]
    shape = tuple(image_embeddings["original_size"])
    if embeddings.ndim == 5:
        # Stacked embeddings: prepend the stack size to the image shape.
        shape = (embeddings.shape[0],) + shape
    return _project_embeddings(embeddings, shape)
def compute_pca(embeddings: numpy.ndarray) -> numpy.ndarray:
def compute_pca(embeddings: np.ndarray) -> np.ndarray:
    """Compute the PCA projection of the embeddings to visualize them as an RGB image.

    Args:
        embeddings: The embeddings, for example predicted by the SAM image encoder.
            A 4d array is projected as a single entry; for a 5d array every entry
            along the first axis is projected separately and the results are stacked.

    Returns:
        PCA of the embeddings, mapped to the pixels, with the projected axis last.

    Raises:
        ValueError: If the embeddings are neither 4d nor 5d.
    """
    ndim = embeddings.ndim
    if ndim not in (4, 5):
        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")

    def _channels_last(embed):
        # embedding_pca's leading output axis (presumably the PCA components) is moved last.
        return embedding_pca(embed.squeeze()).transpose((1, 2, 0))

    if ndim == 4:
        return _channels_last(embeddings)
    return np.stack([_channels_last(embed) for embed in embeddings])

Compute the PCA projection of the embeddings to visualize them as an RGB image.

Arguments:
  • embeddings: The embeddings. For example predicted by the SAM image encoder.
Returns:

PCA of the embeddings, mapped to the pixels.

def project_embeddings_for_visualization( image_embeddings: Dict[str, Any]) -> Tuple[numpy.ndarray, Tuple[float, ...]]:
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to a pixel-wise PCA for visualization.

    Args:
        image_embeddings: The image embeddings. They are handled as tiled
            embeddings if their "input_size" entry is None.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    if image_embeddings["input_size"] is None:
        # Tiled embeddings are stitched before they are projected.
        return _project_tiled_embeddings(image_embeddings)

    embeddings = image_embeddings["features"]
    shape = tuple(image_embeddings["original_size"])
    if embeddings.ndim == 5:
        # Stacked embeddings: prepend the stack size to the image shape.
        shape = (embeddings.shape[0],) + shape
    return _project_embeddings(embeddings, shape)

Project image embeddings to pixel-wise PCA.

Arguments:
  • image_embeddings: The image embeddings.
Returns:

The PCA of the embeddings, and the scale factor for resizing it to the original image size.