micro_sam.visualization

Functionality for visualizing image embeddings.

  1"""Functionality for visualizing image embeddings.
  2"""
  3
  4from typing import Tuple
  5
  6import numpy as np
  7from skimage.transform import resize
  8
  9from nifty.tools import blocking
 10
 11from elf.segmentation.embeddings import embedding_pca
 12
 13from .util import ImageEmbeddings
 14
 15
 16#
 17# PCA visualization for the image embeddings
 18#
 19
 20def compute_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
 21    """Compute the pca projection of the embeddings to visualize them as RGB image.
 22
 23    Args:
 24        embeddings: The embeddings. For example predicted by the SAM image encoder.
 25        n_components: The number of PCA components to use for dimensionality reduction.
 26        as_rgb: Whether to normalize the projected embeddings so that they can be displated as rgb.
 27
 28    Returns:
 29        PCA of the embeddings, mapped to the pixels.
 30    """
 31    if embeddings.ndim == 4:
 32        pca = embedding_pca(embeddings.squeeze(), n_components=n_components, as_rgb=as_rgb).transpose((1, 2, 0))
 33    elif embeddings.ndim == 5:
 34        pca = []
 35        for embed in embeddings:
 36            vis = embedding_pca(embed.squeeze(), n_components=n_components, as_rgb=as_rgb).transpose((1, 2, 0))
 37            pca.append(vis)
 38        pca = np.stack(pca)
 39    else:
 40        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")
 41
 42    return pca
 43
 44
 45def _get_crop(embed_shape, shape):
 46    if shape[0] == shape[1]:  # square image, we don't need to do anything
 47        crop = np.s_[:, :, :]
 48    elif shape[0] > shape[1]:
 49        aspect_ratio = float(shape[1] / shape[0])
 50        crop = np.s_[:, :int(aspect_ratio * embed_shape[1])]
 51    elif shape[1] > shape[0]:
 52        aspect_ratio = float(shape[0] / shape[1])
 53        crop = np.s_[:int(aspect_ratio * embed_shape[0]), :]
 54
 55    return crop
 56
 57
 58def _project_embeddings(embeddings, shape, apply_crop=True, n_components=3, as_rgb=True):
 59    assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"
 60
 61    embedding_vis = compute_pca(embeddings, n_components=n_components, as_rgb=as_rgb)
 62    if not apply_crop:
 63        pass
 64    elif len(shape) == 2:
 65        crop = _get_crop(embedding_vis.shape, shape)
 66        embedding_vis = embedding_vis[crop]
 67    elif len(shape) == 3:
 68        crop = _get_crop(embedding_vis.shape[1:], shape[1:])
 69        crop = (slice(None),) + crop
 70        embedding_vis = embedding_vis[crop]
 71    else:
 72        raise ValueError(f"Expect 2d or 3d data, got {len(shape)}")
 73
 74    scale = tuple(float(sh) / vsh for sh, vsh in zip(shape, embedding_vis.shape))
 75    return embedding_vis, scale
 76
 77
 78def _project_embeddings_to_tile(tile, tile_embeds):
 79    outer_tile = tile.outerBlock
 80    inner_tile_local = tile.innerBlockLocal
 81
 82    embed_shape = tile_embeds.shape[-2:]
 83    outer_tile_shape = tuple(end - beg for beg, end in zip(outer_tile.begin, outer_tile.end))
 84
 85    crop = _get_crop(embed_shape, outer_tile_shape)
 86    crop = (tile_embeds.ndim - len(crop)) * (slice(None),) + crop
 87    this_embeds = tile_embeds[crop]
 88
 89    tile_scale = tuple(esh / float(fsh) for esh, fsh in zip(this_embeds.shape[-2:], outer_tile_shape))
 90    tile_bb = tuple(
 91        slice(int(np.round(beg * scale)), int(np.round(end * scale)))
 92        for beg, end, scale in zip(inner_tile_local.begin, inner_tile_local.end, tile_scale)
 93    )
 94    tile_bb = (tile_embeds.ndim - len(outer_tile_shape)) * (slice(None),) + tile_bb
 95
 96    this_embeds = this_embeds[tile_bb]
 97    return this_embeds
 98
 99
def _resize_and_cocatenate(arrays, axis):
    """Concatenate arrays along one of the last two axes, equalizing the other one first.

    Every array is passed through `resize` so that its extent along the other of the
    last two axes equals the largest such extent among all arrays.

    Args:
        arrays: The arrays to join; the axis index is derived from the first array's ndim.
        axis: The concatenation axis, -1 or -2.

    Returns:
        The concatenated array.
    """
    assert axis in (-1, -2)
    other_axis = -2 if axis == -1 else -1
    target_len = max([arr.shape[other_axis] for arr in arrays])
    other_axis_index = arrays[0].ndim + other_axis

    resized = []
    for arr in arrays:
        new_shape = tuple(
            target_len if dim == other_axis_index else size for dim, size in enumerate(arr.shape)
        )
        resized.append(resize(arr, new_shape))
    return np.concatenate(resized, axis=axis)
110
111
def _project_tiled_embeddings(image_embeddings, n_components, as_rgb):
    """Stitch per-tile embeddings into one map and project it to pixel-wise PCA.

    Args:
        image_embeddings: The tiled embeddings. `features` must carry the attributes
            "tile_shape", "halo" and "shape", and hold one dataset per tile id ("0", "1", ...).
        n_components: The number of PCA components.
        as_rgb: Whether to normalize the projection for display as rgb.

    Returns:
        The projection and the scale factor to the full image shape.
    """
    features = image_embeddings["features"]
    attrs = features.attrs
    tile_shape, halo, shape = attrs["tile_shape"], attrs["halo"], attrs["shape"]
    tiling = blocking([0, 0], shape, tile_shape)

    grid = tiling.blocksPerAxis
    n_rows, n_cols = grid[0], grid[1]
    grid_embeds = [[None] * n_cols for _ in range(n_rows)]

    for tile_id in range(tiling.numberOfBlocks):
        tile_embeds = features[str(tile_id)][:]
        assert tile_embeds.ndim in (4, 5)

        # Keep only the embeddings of the inner tile (i.e. drop the halo).
        tile = tiling.getBlockWithHalo(tile_id, list(halo))
        row, col = tiling.blockGridPosition(tile_id)
        grid_embeds[row][col] = _project_embeddings_to_tile(tile, tile_embeds)

    # Join each row along the last axis, then stack the rows along the second-to-last axis.
    rows = [_resize_and_cocatenate(row_embeds, axis=-1) for row_embeds in grid_embeds]
    stitched = _resize_and_cocatenate(rows, axis=-2)

    if features["0"].ndim == 5:
        shape = (features["0"].shape[0],) + tuple(shape)
    return _project_embeddings(
        stitched, shape, n_components=n_components, as_rgb=as_rgb, apply_crop=False
    )
150
151
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings, n_components: int = 3, as_rgb: bool = True,
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings. If its "input_size" entry is None, the
            embeddings are treated as tiled. Otherwise "features" and "original_size" are read directly.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as rgb.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    # A missing input size marks tiled embeddings.
    is_tiled = image_embeddings["input_size"] is None
    if is_tiled:
        embedding_vis, scale = _project_tiled_embeddings(image_embeddings, n_components, as_rgb)
    else:
        embeddings = image_embeddings["features"]
        shape = tuple(image_embeddings["original_size"])
        # 5d embeddings carry an extra leading axis that is prepended to the target shape.
        if embeddings.ndim == 5:
            shape = (embeddings.shape[0],) + shape
        embedding_vis, scale = _project_embeddings(embeddings, shape, n_components=n_components, as_rgb=as_rgb)

    return embedding_vis, scale
def compute_pca( embeddings: numpy.ndarray, n_components: int = 3, as_rgb: bool = True) -> numpy.ndarray:
def compute_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
    """Compute the pca projection of the embeddings to visualize them as RGB image.

    Args:
        embeddings: The embeddings. For example predicted by the SAM image encoder.
            Either 4d with a leading singleton axis, or 5d where the leading axis is
            iterated over and each entry has a leading singleton axis.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as rgb.

    Returns:
        PCA of the embeddings, mapped to the pixels, with the components moved to the last axis.
        For 5d input the per-entry projections are stacked along a new leading axis.

    Raises:
        ValueError: If the embeddings are neither 4d nor 5d.
    """
    def _project_single(embed):
        # Drop only the leading singleton axis. A plain `squeeze()` would also remove
        # spatial axes of size one and break the transpose below.
        # NOTE(review): assumes embedding_pca returns (n_components, H, W) for (C, H, W) input.
        projected = embedding_pca(embed.squeeze(0), n_components=n_components, as_rgb=as_rgb)
        return projected.transpose((1, 2, 0))

    if embeddings.ndim == 4:
        pca = _project_single(embeddings)
    elif embeddings.ndim == 5:
        pca = np.stack([_project_single(embed) for embed in embeddings])
    else:
        raise ValueError(f"Expect input of ndim 4 or 5, got {embeddings.ndim}")

    return pca

Compute the pca projection of the embeddings to visualize them as RGB image.

Arguments:
  • embeddings: The embeddings. For example predicted by the SAM image encoder.
  • n_components: The number of PCA components to use for dimensionality reduction.
  • as_rgb: Whether to normalize the projected embeddings so that they can be displayed as rgb.
Returns:

PCA of the embeddings, mapped to the pixels.

def project_embeddings_for_visualization( image_embeddings: Dict[str, Any], n_components: int = 3, as_rgb: bool = True) -> Tuple[numpy.ndarray, Tuple[float, ...]]:
def project_embeddings_for_visualization(
    image_embeddings: ImageEmbeddings, n_components: int = 3, as_rgb: bool = True,
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Project image embeddings to pixel-wise PCA.

    Args:
        image_embeddings: The image embeddings. If its "input_size" entry is None, the
            embeddings are treated as tiled. Otherwise "features" and "original_size" are read directly.
        n_components: The number of PCA components to use for dimensionality reduction.
        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as rgb.

    Returns:
        The PCA of the embeddings.
        The scale factor for resizing to the original image size.
    """
    # A missing input size marks tiled embeddings.
    is_tiled = image_embeddings["input_size"] is None
    if is_tiled:
        embedding_vis, scale = _project_tiled_embeddings(image_embeddings, n_components, as_rgb)
    else:
        embeddings = image_embeddings["features"]
        shape = tuple(image_embeddings["original_size"])
        # 5d embeddings carry an extra leading axis that is prepended to the target shape.
        if embeddings.ndim == 5:
            shape = (embeddings.shape[0],) + shape
        embedding_vis, scale = _project_embeddings(embeddings, shape, n_components=n_components, as_rgb=as_rgb)

    return embedding_vis, scale

Project image embeddings to pixel-wise PCA.

Arguments:
  • image_embeddings: The image embeddings.
  • n_components: The number of PCA components to use for dimensionality reduction.
  • as_rgb: Whether to normalize the projected embeddings so that they can be displayed as rgb.
Returns:

A tuple of the PCA of the embeddings and the scale factor for resizing it to the original image size.