micro_sam.precompute_state

Precompute image embeddings and automatic mask generator state for image data.

  1"""Precompute image embeddings and automatic mask generator state for image data.
  2"""
  3
  4import os
  5import pickle
  6from glob import glob
  7from pathlib import Path
  8from functools import partial
  9from typing import Optional, Tuple, Union, List
 10
 11import h5py
 12import numpy as np
 13
 14import torch
 15import torch.nn as nn
 16
 17from segment_anything.predictor import SamPredictor
 18
 19try:
 20    from napari.utils import progress as tqdm
 21except ImportError:
 22    from tqdm import tqdm
 23
 24from . import instance_segmentation, util
 25
 26
 27def cache_amg_state(
 28    predictor: SamPredictor,
 29    raw: np.ndarray,
 30    image_embeddings: util.ImageEmbeddings,
 31    save_path: Union[str, os.PathLike],
 32    verbose: bool = True,
 33    i: Optional[int] = None,
 34    **kwargs,
 35) -> instance_segmentation.AMGBase:
 36    """Compute and cache or load the state for the automatic mask generator.
 37
 38    Args:
 39        predictor: The segment anything predictor.
 40        raw: The image data.
 41        image_embeddings: The image embeddings.
 42        save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
 43        verbose: Whether to run the computation verbosely.
 44        i: The index for which to cache the state.
 45        kwargs: The keyword arguments for the amg class.
 46
 47    Returns:
 48        The automatic mask generator class with the cached state.
 49    """
 50    is_tiled = image_embeddings["input_size"] is None
 51    amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs)
 52
 53    # If i is given we compute the state for a given slice/frame.
 54    # And we have to save the state for slices/frames separately.
 55    if i is None:
 56        save_path_amg = os.path.join(save_path, "amg_state.pickle")
 57    else:
 58        os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
 59        save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")
 60
 61    if os.path.exists(save_path_amg):
 62        if verbose:
 63            print("Load the AMG state from", save_path_amg)
 64        with open(save_path_amg, "rb") as f:
 65            amg_state = pickle.load(f)
 66        amg.set_state(amg_state)
 67        return amg
 68
 69    if verbose:
 70        print("Precomputing the state for instance segmentation.")
 71
 72    amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i)
 73    amg_state = amg.get_state()
 74
 75    # put all state onto the cpu so that the state can be deserialized without a gpu
 76    new_crop_list = []
 77    for mask_data in amg_state["crop_list"]:
 78        for k, v in mask_data.items():
 79            if torch.is_tensor(v):
 80                mask_data[k] = v.cpu()
 81        new_crop_list.append(mask_data)
 82    amg_state["crop_list"] = new_crop_list
 83
 84    with open(save_path_amg, "wb") as f:
 85        pickle.dump(amg_state, f)
 86
 87    return amg
 88
 89
 90def cache_is_state(
 91    predictor: SamPredictor,
 92    decoder: torch.nn.Module,
 93    raw: np.ndarray,
 94    image_embeddings: util.ImageEmbeddings,
 95    save_path: Union[str, os.PathLike],
 96    verbose: bool = True,
 97    i: Optional[int] = None,
 98    skip_load: bool = False,
 99    **kwargs,
100) -> Optional[instance_segmentation.AMGBase]:
101    """Compute and cache or load the state for the automatic mask generator.
102
103    Args:
104        predictor: The segment anything predictor.
105        decoder: The instance segmentation decoder.
106        raw: The image data.
107        image_embeddings: The image embeddings.
108        save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
109        verbose: Whether to run the computation verbosely.
110        i: The index for which to cache the state.
111        skip_load: Skip loading the state if it is precomputed.
112        kwargs: The keyword arguments for the amg class.
113
114    Returns:
115        The instance segmentation class with the cached state.
116    """
117    is_tiled = image_embeddings["input_size"] is None
118    amg = instance_segmentation.get_amg(predictor, is_tiled, decoder=decoder, **kwargs)
119
120    # If i is given we compute the state for a given slice/frame.
121    # And we have to save the state for slices/frames separately.
122    save_path = os.path.join(save_path, "is_state.h5")
123    save_key = "state" if i is None else f"state-{i}"
124
125    with h5py.File(save_path, "a") as f:
126        if save_key in f:
127            if skip_load:  # Skip loading to speed this up for cases where we don't need the return val.
128                return
129
130            if verbose:
131                print("Load instance segmentation state from", save_path, ":", save_key)
132            g = f[save_key]
133            state = {
134                "foreground": g["foreground"][:],
135                "boundary_distances": g["boundary_distances"][:],
136                "center_distances": g["center_distances"][:],
137            }
138            amg.set_state(state)
139            return amg
140
141    if verbose:
142        print("Precomputing the state for instance segmentation.")
143
144    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
145    state = amg.get_state()
146
147    with h5py.File(save_path, "a") as f:
148        g = f.create_group(save_key)
149        g.create_dataset("foreground", data=state["foreground"], compression="gzip")
150        g.create_dataset("boundary_distances", data=state["boundary_distances"], compression="gzip")
151        g.create_dataset("center_distances", data=state["center_distances"], compression="gzip")
152
153    return amg
154
155
156def _precompute_state_for_file(
157    predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state, decoder, verbose
158):
159    if isinstance(input_path, np.ndarray):
160        image_data = input_path
161    else:
162        image_data = util.load_image_data(input_path, key)
163
164    # Precompute the image embeddings.
165    output_path = Path(output_path).with_suffix(".zarr")
166    embeddings = util.precompute_image_embeddings(
167        predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=verbose
168    )
169
170    # Precompute the state for automatic instance segmentation (AMG or AIS).
171    if precompute_amg_state:
172        if decoder is None:
173            cache_function = partial(
174                cache_amg_state, predictor=predictor, image_embeddings=embeddings, save_path=output_path
175            )
176        else:
177            cache_function = partial(
178                cache_is_state, predictor=predictor, decoder=decoder,
179                image_embeddings=embeddings, save_path=output_path
180            )
181
182        if ndim is None:
183            ndim = image_data.ndim
184
185        if ndim == 2:
186            cache_function(raw=image_data, verbose=verbose)
187        else:
188            n = image_data.shape[0]
189            for i in tqdm(range(n), total=n, desc="Precompute instance segmentation state", disable=not verbose):
190                cache_function(raw=image_data, i=i, verbose=False)
191
192
193def _precompute_state_for_files(
194    predictor: SamPredictor,
195    input_files: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
196    output_path: Union[os.PathLike, str],
197    key: Optional[str] = None,
198    ndim: Optional[int] = None,
199    tile_shape: Optional[Tuple[int, int]] = None,
200    halo: Optional[Tuple[int, int]] = None,
201    precompute_amg_state: bool = False,
202    decoder: Optional["nn.Module"] = None,
203):
204    os.makedirs(output_path, exist_ok=True)
205    for i, file_path in enumerate(tqdm(input_files, total=len(input_files), desc="Precompute state for files")):
206
207        if isinstance(file_path, np.ndarray):
208            out_path = os.path.join(output_path, f"embedding_{i:05}.tif")
209        else:
210            out_path = os.path.join(output_path, os.path.basename(file_path))
211
212        _precompute_state_for_file(
213            predictor, file_path, out_path,
214            key=key, ndim=ndim, tile_shape=tile_shape, halo=halo,
215            precompute_amg_state=precompute_amg_state, decoder=decoder,
216            verbose=False,
217        )
218
219
220def precompute_state(
221    input_path: Union[os.PathLike, str],
222    output_path: Union[os.PathLike, str],
223    pattern: Optional[str] = None,
224    model_type: str = util._DEFAULT_MODEL,
225    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
226    key: Optional[str] = None,
227    ndim: Optional[int] = None,
228    tile_shape: Optional[Tuple[int, int]] = None,
229    halo: Optional[Tuple[int, int]] = None,
230    precompute_amg_state: bool = False,
231) -> None:
232    """Precompute the image embeddings and other optional state for the input image(s).
233
234    Args:
235        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
236            a container file (e.g. hdf5 or zarr) or a folder with image files.
237            In case of a container file the argument `key` must be given. In case of a folder
238            it can be given to provide a glob pattern to subselect files from the folder.
239        output_path: The output path where the embeddings and other state will be saved.
240        pattern: Glob pattern to select files in a folder. The embeddings will be computed
241            for each of these files. To select all files in a folder pass "*".
242        model_type: The Segment Anything model to use. Will use the standard vit_l model by default.
243        checkpoint_path: Path to a checkpoint for a custom model.
244        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
245            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
246        ndim: The dimensionality of the data.
247        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
248        halo: Overlap of the tiles for tiled prediction.
249        precompute_amg_state: Whether to precompute the state for automatic instance segmentation
250            in addition to the image embeddings.
251    """
252    predictor, state = util.get_sam_model(
253        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True,
254    )
255    if "decoder_state" in state:
256        decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])
257    else:
258        decoder = None
259
260    # Check if we precompute the state for a single file or for a folder with image files.
261    if pattern is None:
262        _precompute_state_for_file(
263            predictor, input_path, output_path, key,
264            ndim=ndim, tile_shape=tile_shape, halo=halo,
265            precompute_amg_state=precompute_amg_state,
266            decoder=decoder, verbose=True,
267        )
268    else:
269        input_files = glob(os.path.join(input_path, pattern))
270        _precompute_state_for_files(
271            predictor, input_files, output_path, key=key,
272            ndim=ndim, tile_shape=tile_shape, halo=halo,
273            precompute_amg_state=precompute_amg_state,
274            decoder=decoder,
275        )
276
277
278def main():
279    """@private"""
280    import argparse
281
282    available_models = list(util.get_model_names())
283    available_models = ", ".join(available_models)
284
285    parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
286    parser.add_argument(
287        "-i", "--input_path", required=True,
288        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
289        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
290    )
291    parser.add_argument(
292        "-e", "--embedding_path", required=True, help="The path where the embeddings will be saved."
293    )
294
295    parser.add_argument(
296        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
297    )
298    parser.add_argument(
299        "-k", "--key",
300        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
301        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
302    )
303
304    parser.add_argument(
305        "-m", "--model_type", default=util._DEFAULT_MODEL,
306        help=f"The segment anything model that will be used, one of {available_models}."
307    )
308    parser.add_argument(
309        "-c", "--checkpoint", default=None,
310        help="Checkpoint from which the SAM model will be loaded loaded."
311    )
312    parser.add_argument(
313        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
314    )
315    parser.add_argument(
316        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
317    )
318    parser.add_argument(
319        "-n", "--ndim", type=int, default=None,
320        help="The number of spatial dimensions in the data. "
321        "Please specify this if your data has a channel dimension."
322    )
323    parser.add_argument(
324        "-p", "--precompute_amg_state", action="store_true",
325        help="Whether to precompute the state for automatic instance segmentation."
326    )
327
328    args = parser.parse_args()
329    precompute_state(
330        args.input_path, args.embedding_path,
331        model_type=args.model_type, checkpoint_path=args.checkpoint,
332        pattern=args.pattern, key=args.key,
333        tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
334        precompute_amg_state=args.precompute_amg_state,
335    )
336
337
338if __name__ == "__main__":
339    main()
def cache_amg_state( predictor: segment_anything.predictor.SamPredictor, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, **kwargs) -> micro_sam.instance_segmentation.AMGBase:

Compute and cache or load the state for the automatic mask generator.

Arguments:
  • predictor: The segment anything predictor.
  • raw: The image data.
  • image_embeddings: The image embeddings.
  • save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
  • verbose: Whether to run the computation verbosely.
  • i: The index for which to cache the state.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator class with the cached state.
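
Example (a minimal sketch: the model name 'vit_b', the file paths and the random test image are placeholders; a 2d numpy image is assumed):

    import numpy as np
    from micro_sam import util
    from micro_sam.precompute_state import cache_amg_state

    image = np.random.randint(0, 255, (512, 512), dtype="uint8")  # placeholder image data
    predictor = util.get_sam_model(model_type="vit_b")

    # Compute (or load) the image embeddings; they are cached in the zarr directory.
    embeddings = util.precompute_image_embeddings(predictor, image, "embeddings.zarr", ndim=2)

    # The AMG state is cached in 'embeddings.zarr/amg_state.pickle'.
    # A second call with the same save_path loads the pickled state instead of recomputing it.
    amg = cache_amg_state(predictor, image, embeddings, save_path="embeddings.zarr")
    masks = amg.generate()  # masks in the segment-anything output format

For volumetric data or timeseries the state is computed per slice/frame by passing the index i, as done in _precompute_state_for_file above.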

def cache_is_state( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, skip_load: bool = False, **kwargs) -> Optional[micro_sam.instance_segmentation.AMGBase]:

Compute and cache or load the state for automatic instance segmentation.

Arguments:
  • predictor: The segment anything predictor.
  • decoder: The instance segmentation decoder.
  • raw: The image data.
  • image_embeddings: The image embeddings.
  • save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
  • verbose: Whether to run the computation verbosely.
  • i: The index for which to cache the state.
  • skip_load: Skip loading the state if it is precomputed.
  • kwargs: The keyword arguments for the amg class.
Returns:

The instance segmentation class with the cached state.
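
Example (a sketch of the decoder-based variant; it assumes a micro_sam model that comes with a decoder state, here named 'vit_b_lm' as a placeholder, and mirrors how precompute_state below constructs the decoder):

    import numpy as np
    from micro_sam import util, instance_segmentation
    from micro_sam.precompute_state import cache_is_state

    image = np.random.randint(0, 255, (512, 512), dtype="uint8")  # placeholder image data
    predictor, state = util.get_sam_model(model_type="vit_b_lm", return_state=True)
    decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])

    embeddings = util.precompute_image_embeddings(predictor, image, "embeddings.zarr", ndim=2)

    # The foreground and distance predictions are cached in 'embeddings.zarr/is_state.h5'.
    # Subsequent calls load them from the h5 file; pass skip_load=True if the return value is not needed.
    ais = cache_is_state(predictor, decoder, image, embeddings, save_path="embeddings.zarr")
    instances = ais.generate()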

def precompute_state( input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], pattern: Optional[str] = None, model_type: str = 'vit_l', checkpoint_path: Union[str, os.PathLike, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, precompute_amg_state: bool = False) -> None:

Precompute the image embeddings and other optional state for the input image(s).

Arguments:
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), a container file (e.g. hdf5 or zarr) or a folder with image files. In case of a container file the argument key must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder.
  • output_path: The output path where the embeddings and other state will be saved.
  • pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*".
  • model_type: The Segment Anything model to use. Will use the standard vit_l model by default.
  • checkpoint_path: Path to a checkpoint for a custom model.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings.
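
Example (a usage sketch for the top-level function; the file names, folder layout and model name are assumptions):

    from micro_sam.precompute_state import precompute_state

    # Single 2d image: embeddings (and, optionally, the AMG/AIS state) are written to 'image_embeddings.zarr'.
    precompute_state(
        "image.tif", output_path="image_embeddings.zarr",
        model_type="vit_b", ndim=2, precompute_amg_state=True,
    )

    # Folder of images: 'pattern' selects the files and one embedding file per image is created in 'embeddings'.
    precompute_state(
        "data", output_path="embeddings", pattern="*.tif",
        model_type="vit_b", ndim=2,
    )

The same computation can be run from the command line, e.g. python -m micro_sam.precompute_state -i data -e embeddings --pattern "*.tif" -m vit_b -n 2; see the argument parser in main above for the full set of options.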