micro_sam.visualization

Functionality for visualizing image embeddings.

  1"""Functionality for visualizing image embeddings.
  2"""
  3
  4from typing import Tuple
  5
  6import numpy as np
  7from skimage.transform import resize
  8
  9from nifty.tools import blocking
 10
 11from elf.segmentation.embeddings import embedding_pca
 12
 13from .util import ImageEmbeddings
 14
 15
 16#
 17# PCA visualization for the image embeddings
 18#
 19
 20def compute_pca(embeddings: np.ndarray) -> np.ndarray:
 21    """Compute the pca projection of the embeddings to visualize them as RGB image.
 22
 23    Args:
 24        embeddings: The embeddings. For example predicted by the SAM image encoder.
 25
 26    Returns:
 27        PCA of the embeddings, mapped to the pixels.
 28    """
 29    if embeddings.ndim == 4:
 30        pca = embedding_pca(embeddings.squeeze()).transpose((1, 2, 0))
 31    elif embeddings.ndim == 5:
 32        pca = []
 33        for embed in embeddings:
 34            vis = embedding_pca(embed.squeeze()).transpose((1, 2, 0))
 35            pca.append(vis)
 36        pca = np.stack(pca)
 37    else:
 38        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")
 39
 40    return pca
 41
 42
 43def _get_crop(embed_shape, shape):
 44    if shape[0] == shape[1]:  # square image, we don't need to do anything
 45        crop = np.s_[:, :, :]
 46    elif shape[0] > shape[1]:
 47        aspect_ratio = float(shape[1] / shape[0])
 48        crop = np.s_[:, :int(aspect_ratio * embed_shape[1])]
 49    elif shape[1] > shape[0]:
 50        aspect_ratio = float(shape[0] / shape[1])
 51        crop = np.s_[:int(aspect_ratio * embed_shape[0]), :]
 52
 53    return crop
 54
 55
 56def _project_embeddings(embeddings, shape, apply_crop=True):
 57    assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"
 58
 59    embedding_vis = compute_pca(embeddings)
 60    if not apply_crop:
 61        pass
 62    elif len(shape) == 2:
 63        crop = _get_crop(embedding_vis.shape, shape)
 64        embedding_vis = embedding_vis[crop]
 65    elif len(shape) == 3:
 66        crop = _get_crop(embedding_vis.shape[1:], shape[1:])
 67        crop = (slice(None),) + crop
 68        embedding_vis = embedding_vis[crop]
 69    else:
 70        raise ValueError(f"Expect 2d or 3d data, got {len(shape)}")
 71
 72    scale = tuple(float(sh) / vsh for sh, vsh in zip(shape, embedding_vis.shape))
 73    return embedding_vis, scale
 74
 75
 76def _project_embeddings_to_tile(tile, tile_embeds):
 77    outer_tile = tile.outerBlock
 78    inner_tile_local = tile.innerBlockLocal
 79
 80    embed_shape = tile_embeds.shape[-2:]
 81    outer_tile_shape = tuple(end - beg for beg, end in zip(outer_tile.begin, outer_tile.end))
 82
 83    crop = _get_crop(embed_shape, outer_tile_shape)
 84    crop = (tile_embeds.ndim - len(crop)) * (slice(None),) + crop
 85    this_embeds = tile_embeds[crop]
 86
 87    tile_scale = tuple(esh / float(fsh) for esh, fsh in zip(this_embeds.shape[-2:], outer_tile_shape))
 88    tile_bb = tuple(
 89        slice(int(np.round(beg * scale)), int(np.round(end * scale)))
 90        for beg, end, scale in zip(inner_tile_local.begin, inner_tile_local.end, tile_scale)
 91    )
 92    tile_bb = (tile_embeds.ndim - len(outer_tile_shape)) * (slice(None),) + tile_bb
 93
 94    this_embeds = this_embeds[tile_bb]
 95    return this_embeds
 96
 97
 98def _resize_and_cocatenate(arrays, axis):
 99    assert axis in (-1, -2)
100    resize_axis = -1 if axis == -2 else -2
101    resize_len = max([arr.shape[resize_axis] for arr in arrays])
102
103    def resize_shape(shape):
104        axis_ = arrays[0].ndim + resize_axis
105        return tuple(resize_len if i == axis_ else sh for i, sh in enumerate(shape))
106
107    return np.concatenate([resize(arr, resize_shape(arr.shape)) for arr in arrays], axis=axis)
108
109
110def _project_tiled_embeddings(image_embeddings):
111    features = image_embeddings["features"]
112    tile_shape, halo, shape = features.attrs["tile_shape"], features.attrs["halo"], features.attrs["shape"]
113    tiling = blocking([0, 0], shape, tile_shape)
114
115    tile_grid = tiling.blocksPerAxis
116
117    embeds = {
118        i: {j: None for j in range(tile_grid[1])} for i in range(tile_grid[0])
119    }
120
121    for tile_id in range(tiling.numberOfBlocks):
122        tile_embeds = features[tile_id][:]
123        assert tile_embeds.ndim in (4, 5)
124
125        # extract the embeddings corresponding to the inner tile
126        tile = tiling.getBlockWithHalo(tile_id, list(halo))
127        tile_coords = tiling.blockGridPosition(tile_id)
128        this_embeds = _project_embeddings_to_tile(tile, tile_embeds)
129
130        i, j = tile_coords
131        embeds[i][j] = this_embeds
132
133    embeds = _resize_and_cocatenate(
134        [
135            _resize_and_cocatenate(
136                [embeds[i][j] for j in range(tile_grid[1])], axis=-1
137            )
138            for i in range(tile_grid[0])
139        ], axis=-2
140    )
141
142    if features[0].ndim == 5:
143        shape = (features[0].shape[0],) + tuple(shape)
144    embedding_vis, scale = _project_embeddings(embeds, shape, apply_crop=False)
145    return embedding_vis, scale
146
147
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to a pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings. They are treated as tiled when
            ``image_embeddings["input_size"]`` is None.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    if image_embeddings["input_size"] is None:
        vis, scale = _project_tiled_embeddings(image_embeddings)
        return vis, scale

    features = image_embeddings["features"]
    target_shape = tuple(image_embeddings["original_size"])
    if features.ndim == 5:
        # 5-D features get their first axis prepended to the target shape.
        target_shape = (features.shape[0],) + target_shape
    vis, scale = _project_embeddings(features, target_shape)
    return vis, scale
def compute_pca(embeddings: numpy.ndarray) -> numpy.ndarray:
def compute_pca(embeddings: np.ndarray) -> np.ndarray:
    """Project embeddings onto their principal components so they can be shown as an RGB image.

    Args:
        embeddings: The embeddings, for example as predicted by the SAM image encoder.
            A 4-dimensional array is treated as a single embedding; a 5-dimensional
            array is treated as a stack whose first axis is iterated over.

    Returns:
        PCA of the embeddings, mapped to the pixels. The leading axis of each
        ``embedding_pca`` result is moved to the end.

    Raises:
        ValueError: If the embeddings are neither 4- nor 5-dimensional.
    """
    def _pca_channels_last(single_embedding):
        # Drop singleton axes, run the PCA, then move its leading axis last.
        return embedding_pca(single_embedding.squeeze()).transpose((1, 2, 0))

    n_dims = embeddings.ndim
    if n_dims == 4:
        return _pca_channels_last(embeddings)
    if n_dims == 5:
        return np.stack([_pca_channels_last(embed) for embed in embeddings])
    raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")

Compute the PCA projection of the embeddings to visualize them as an RGB image.

Arguments:
  • embeddings: The embeddings, for example as predicted by the SAM image encoder.
Returns:

PCA of the embeddings, mapped to the pixels.

def project_embeddings_for_visualization( image_embeddings: Dict[str, Any]) -> Tuple[numpy.ndarray, Tuple[float, ...]]:
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to a pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings. They are treated as tiled when
            ``image_embeddings["input_size"]`` is None.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    if image_embeddings["input_size"] is None:
        vis, scale = _project_tiled_embeddings(image_embeddings)
        return vis, scale

    features = image_embeddings["features"]
    target_shape = tuple(image_embeddings["original_size"])
    if features.ndim == 5:
        # 5-D features get their first axis prepended to the target shape.
        target_shape = (features.shape[0],) + target_shape
    vis, scale = _project_embeddings(features, target_shape)
    return vis, scale

Project image embeddings to pixel-wise PCA.

Arguments:
  • image_embeddings: The image embeddings.
Returns:

The PCA of the embeddings and the scale factor for resizing it to the original image size.