micro_sam.precompute_state

Precompute image embeddings and automatic mask generator state for image data.
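
A minimal usage sketch (the file names below are placeholders): precompute the embeddings for a single 2D image and also cache the state for automatic instance segmentation, either from Python or via the command line interface defined in this module.

    from micro_sam.precompute_state import precompute_state

    # Hypothetical paths; adjust to your data.
    precompute_state(
        input_path="image.tif",           # a single image file
        output_path="embeddings.zarr",    # embeddings are stored as zarr
        ndim=2,
        precompute_amg_state=True,        # also cache the AMG / AIS state
    )

    # Equivalent command line call:
    # python -m micro_sam.precompute_state -i image.tif -e embeddings.zarr -n 2 -p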

  1"""Precompute image embeddings and automatic mask generator state for image data.
  2"""
  3
  4import os
  5import pickle
  6from glob import glob
  7from pathlib import Path
  8from functools import partial
  9from typing import Optional, Tuple, Union, List
 10
 11import h5py
 12import numpy as np
 13
 14import torch
 15import torch.nn as nn
 16
 17from segment_anything.predictor import SamPredictor
 18
 19try:
 20    from napari.utils import progress as tqdm
 21except ImportError:
 22    from tqdm import tqdm
 23
 24from . import instance_segmentation, util
 25
 26
def cache_amg_state(
    predictor: SamPredictor,
    raw: np.ndarray,
    image_embeddings: util.ImageEmbeddings,
    save_path: Union[str, os.PathLike],
    verbose: bool = True,
    i: Optional[int] = None,
    **kwargs,
) -> instance_segmentation.AMGBase:
    """Compute and cache or load the state for the automatic mask generator.

    Args:
        predictor: The segment anything predictor.
        raw: The image data.
        image_embeddings: The image embeddings.
        save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
        verbose: Whether to run the computation verbosely.
        i: The index for which to cache the state.
        kwargs: The keyword arguments for the amg class.

    Returns:
        The automatic mask generator class with the cached state.
    """
    is_tiled = image_embeddings["input_size"] is None
    amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs)

    # If i is given we compute the state for a given slice/frame.
    # And we have to save the state for slices/frames separately.
    if i is None:
        save_path_amg = os.path.join(save_path, "amg_state.pickle")
    else:
        os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
        save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")

    if os.path.exists(save_path_amg):
        if verbose:
            print("Load the AMG state from", save_path_amg)
        with open(save_path_amg, "rb") as f:
            amg_state = pickle.load(f)
        amg.set_state(amg_state)
        return amg

    if verbose:
        print("Precomputing the state for instance segmentation.")

    amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i)
    amg_state = amg.get_state()

    # put all state onto the cpu so that the state can be deserialized without a gpu
    new_crop_list = []
    for mask_data in amg_state["crop_list"]:
        for k, v in mask_data.items():
            if torch.is_tensor(v):
                mask_data[k] = v.cpu()
        new_crop_list.append(mask_data)
    amg_state["crop_list"] = new_crop_list

    with open(save_path_amg, "wb") as f:
        pickle.dump(amg_state, f)

    return amg


def cache_is_state(
    predictor: SamPredictor,
    decoder: torch.nn.Module,
    raw: np.ndarray,
    image_embeddings: util.ImageEmbeddings,
    save_path: Union[str, os.PathLike],
    verbose: bool = True,
    i: Optional[int] = None,
    skip_load: bool = False,
    **kwargs,
) -> Optional[instance_segmentation.AMGBase]:
    """Compute and cache or load the state for automatic instance segmentation.

    Args:
        predictor: The segment anything predictor.
        decoder: The instance segmentation decoder.
        raw: The image data.
        image_embeddings: The image embeddings.
        save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
        verbose: Whether to run the computation verbosely.
        i: The index for which to cache the state.
        skip_load: Skip loading the state if it is precomputed.
        kwargs: The keyword arguments for the amg class.

    Returns:
        The instance segmentation class with the cached state.
    """
    is_tiled = image_embeddings["input_size"] is None
    amg = instance_segmentation.get_amg(predictor, is_tiled, decoder=decoder, **kwargs)

    # If i is given we compute the state for a given slice/frame.
    # And we have to save the state for slices/frames separately.
    save_path = os.path.join(save_path, "is_state.h5")
    save_key = "state" if i is None else f"state-{i}"

    with h5py.File(save_path, "a") as f:
        if save_key in f:
            if skip_load:  # Skip loading to speed this up for cases where we don't need the return val.
                return

            if verbose:
                print("Load instance segmentation state from", save_path, ":", save_key)
            g = f[save_key]
            state = {
                "foreground": g["foreground"][:],
                "boundary_distances": g["boundary_distances"][:],
                "center_distances": g["center_distances"][:],
            }
            amg.set_state(state)
            return amg

    if verbose:
        print("Precomputing the state for instance segmentation.")

    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
    state = amg.get_state()

    with h5py.File(save_path, "a") as f:
        g = f.create_group(save_key)
        g.create_dataset("foreground", data=state["foreground"], compression="gzip")
        g.create_dataset("boundary_distances", data=state["boundary_distances"], compression="gzip")
        g.create_dataset("center_distances", data=state["center_distances"], compression="gzip")

    return amg


def _precompute_state_for_file(
    predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state, decoder, verbose
):
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    # Precompute the image embeddings.
    output_path = Path(output_path).with_suffix(".zarr")
    embeddings = util.precompute_image_embeddings(
        predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=verbose
    )

    # Precompute the state for automatic instance segmentation (AMG or AIS).
    if precompute_amg_state:
        if decoder is None:
            cache_function = partial(
                cache_amg_state, predictor=predictor, image_embeddings=embeddings, save_path=output_path
            )
        else:
            cache_function = partial(
                cache_is_state, predictor=predictor, decoder=decoder,
                image_embeddings=embeddings, save_path=output_path
            )

        if ndim is None:
            ndim = image_data.ndim

        if ndim == 2:
            cache_function(raw=image_data, verbose=verbose)
        else:
            n = image_data.shape[0]
            for i in tqdm(range(n), total=n, desc="Precompute instance segmentation state", disable=not verbose):
                cache_function(raw=image_data, i=i, verbose=False)


def _precompute_state_for_files(
    predictor: SamPredictor,
    input_files: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_path: Union[os.PathLike, str],
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False,
    decoder: Optional["nn.Module"] = None,
):
    os.makedirs(output_path, exist_ok=True)
    idx = 0
    for file_path in tqdm(input_files, total=len(input_files), desc="Precompute state for files"):

        if isinstance(file_path, np.ndarray):
            out_path = os.path.join(output_path, f"embedding_{idx:05}.tif")
        else:
            out_path = os.path.join(output_path, os.path.basename(file_path))

        _precompute_state_for_file(
            predictor, file_path, out_path,
            key=key, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
            verbose=False,
        )
        idx += 1


def precompute_state(
    input_path: Union[os.PathLike, str],
    output_path: Union[os.PathLike, str],
    pattern: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False,
) -> None:
    """Precompute the image embeddings and other optional state for the input image(s).

    Args:
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            a container file (e.g. hdf5 or zarr) or a folder with image files.
            In case of a container file the argument `key` must be given. In case of a folder
            it can be given to provide a glob pattern to subselect files from the folder.
        output_path: The output path where the embeddings and other state will be saved.
        pattern: Glob pattern to select files in a folder. The embeddings will be computed
            for each of these files. To select all files in a folder pass "*".
        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
        checkpoint_path: Path to a checkpoint for a custom model.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        precompute_amg_state: Whether to precompute the state for automatic instance segmentation
            in addition to the image embeddings.
    """
    predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

    if "decoder_state" in state:
        decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])
    else:
        decoder = None

    # Check if we precompute the state for a single file or for a folder with image files.
    if pattern is None:
        _precompute_state_for_file(
            predictor, input_path, output_path, key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
            decoder=decoder, verbose=True,
        )
    else:
        input_files = glob(os.path.join(input_path, pattern))
        _precompute_state_for_files(
            predictor, input_files, output_path, key=key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
            decoder=decoder,
        )


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-e", "--embedding_path", required=True, help="The path where the embeddings will be saved."
    )
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. "
        "Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "-p", "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic instance segmentation."
    )

    args = parser.parse_args()
    precompute_state(
        args.input_path, args.embedding_path,
        model_type=args.model_type, checkpoint_path=args.checkpoint,
        pattern=args.pattern, key=args.key,
        tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
        precompute_amg_state=args.precompute_amg_state,
    )


if __name__ == "__main__":
    main()
def cache_amg_state( predictor: segment_anything.predictor.SamPredictor, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, **kwargs) -> micro_sam.instance_segmentation.AMGBase:

Compute and cache or load the state for the automatic mask generator.

Arguments:
  • predictor: The segment anything predictor.
  • raw: The image data.
  • image_embeddings: The image embeddings.
  • save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
  • verbose: Whether to run the computation verbosely.
  • i: The index for which to cache the state.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator class with the cached state.
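
A hedged sketch of how this function can be called directly; the data and paths are placeholders, and the helper calls come from micro_sam.util as used elsewhere in this module.

    import numpy as np
    from micro_sam import util
    from micro_sam.precompute_state import cache_amg_state

    # Hypothetical 2D example data and embedding folder.
    raw = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
    save_path = "embeddings.zarr"

    predictor = util.get_sam_model(model_type="vit_b")
    image_embeddings = util.precompute_image_embeddings(predictor, raw, save_path, ndim=2)

    # Computes the AMG state and stores it in 'embeddings.zarr/amg_state.pickle'.
    # A second call with the same save_path loads the cached state instead of recomputing it.
    amg = cache_amg_state(predictor, raw, image_embeddings, save_path)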

def cache_is_state( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, skip_load: bool = False, **kwargs) -> Optional[micro_sam.instance_segmentation.AMGBase]:

Compute and cache or load the state for automatic instance segmentation.

Arguments:
  • predictor: The segment anything predictor.
  • decoder: The instance segmentation decoder.
  • raw: The image data.
  • image_embeddings: The image embeddings.
  • save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
  • verbose: Whether to run the computation verbosely.
  • i: The index for which to cache the state.
  • skip_load: Skip loading the state if it is precomputed.
  • kwargs: The keyword arguments for the amg class.
Returns:

The instance segmentation class with the cached state.
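
A hedged sketch with placeholder data, following the same pattern as precompute_state in this module; the decoder is only available for models that ship a 'decoder_state' (e.g. 'vit_b_lm').

    import numpy as np
    from micro_sam import util, instance_segmentation
    from micro_sam.precompute_state import cache_is_state

    # Hypothetical 2D example data and embedding folder.
    raw = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
    save_path = "embeddings.zarr"

    # Load the predictor together with the additional decoder state.
    predictor, state = util.get_sam_model(model_type="vit_b_lm", return_state=True)
    decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])

    image_embeddings = util.precompute_image_embeddings(predictor, raw, save_path, ndim=2)

    # Computes the AIS state and stores it in 'embeddings.zarr/is_state.h5'.
    # A second call with the same save_path loads the cached state instead of recomputing it.
    amg = cache_is_state(predictor, decoder, raw, image_embeddings, save_path)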

def precompute_state( input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], pattern: Optional[str] = None, model_type: str = 'vit_b_lm', checkpoint_path: Union[str, os.PathLike, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, precompute_amg_state: bool = False) -> None:

Precompute the image embeddings and other optional state for the input image(s).

Arguments:
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), a container file (e.g. hdf5 or zarr) or a folder with image files. In case of a container file the argument key must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder.
  • output_path: The output path where the embeddings and other state will be saved.
  • pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*".
  • model_type: The Segment Anything model to use. Will use the vit_b_lm model by default.
  • checkpoint_path: Path to a checkpoint for a custom model.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings.
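
Two hedged usage sketches with placeholder paths and keys, covering a container file and a folder of images with tiled embeddings.

    from micro_sam.precompute_state import precompute_state

    # A 3D volume stored in an hdf5 container; 'key' is the internal dataset path (placeholder).
    precompute_state(
        "volume.h5", "volume_embeddings.zarr",
        key="raw", ndim=3, precompute_amg_state=True,
    )

    # All tif files in a folder, with tiled embedding computation.
    precompute_state(
        "images/", "embeddings/",
        pattern="*.tif", ndim=2,
        tile_shape=(1024, 1024), halo=(256, 256),
    )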