micro_sam.precompute_state

Precompute image embeddings and automatic mask generator state for image data.

  1"""Precompute image embeddings and automatic mask generator state for image data.
  2"""
  3
  4import os
  5import pickle
  6from glob import glob
  7from pathlib import Path
  8from functools import partial
  9from typing import Optional, Tuple, Union, List
 10
 11import h5py
 12import numpy as np
 13
 14import torch
 15import torch.nn as nn
 16
 17from segment_anything.predictor import SamPredictor
 18
 19try:
 20    from napari.utils import progress as tqdm
 21except ImportError:
 22    from tqdm import tqdm
 23
 24from . import instance_segmentation, util
 25
 26
 27def cache_amg_state(
 28    predictor: SamPredictor,
 29    raw: np.ndarray,
 30    image_embeddings: util.ImageEmbeddings,
 31    save_path: Union[str, os.PathLike],
 32    verbose: bool = True,
 33    i: Optional[int] = None,
 34    **kwargs,
 35) -> instance_segmentation.AMGBase:
 36    """Compute and cache or load the state for the automatic mask generator.
 37
 38    Args:
 39        predictor: The Segment Anything predictor.
 40        raw: The image data.
 41        image_embeddings: The image embeddings.
 42        save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
 43        verbose: Whether to run the computation verbose. By default, set to 'True'.
 44        i: The index for which to cache the state.
 45        kwargs: The keyword arguments for the amg class.
 46
 47    Returns:
 48        The automatic mask generator class with the cached state.
 49    """
 50    is_tiled = image_embeddings["input_size"] is None
 51    amg = instance_segmentation.get_instance_segmentation_generator(predictor, is_tiled=is_tiled, **kwargs)
 52
 53    # If i is given we compute the state for a given slice/frame.
 54    # And we have to save the state for slices/frames separately.
 55    if i is None:
 56        save_path_amg = os.path.join(save_path, "amg_state.pickle")
 57    else:
 58        os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
 59        save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")
 60
 61    if os.path.exists(save_path_amg):
 62        if verbose:
 63            print("Load the AMG state from", save_path_amg)
 64        with open(save_path_amg, "rb") as f:
 65            amg_state = pickle.load(f)
 66        amg.set_state(amg_state)
 67        return amg
 68
 69    if verbose:
 70        print("Precomputing the state for instance segmentation.")
 71
 72    amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i)
 73    amg_state = amg.get_state()
 74
 75    # put all state onto the cpu so that the state can be deserialized without a gpu
 76    new_crop_list = []
 77    for mask_data in amg_state["crop_list"]:
 78        for k, v in mask_data.items():
 79            if torch.is_tensor(v):
 80                mask_data[k] = v.cpu()
 81        new_crop_list.append(mask_data)
 82    amg_state["crop_list"] = new_crop_list
 83
 84    with open(save_path_amg, "wb") as f:
 85        pickle.dump(amg_state, f)
 86
 87    return amg
 88
 89
 90def cache_is_state(
 91    predictor: SamPredictor,
 92    decoder: torch.nn.Module,
 93    raw: np.ndarray,
 94    image_embeddings: util.ImageEmbeddings,
 95    save_path: Union[str, os.PathLike],
 96    verbose: bool = True,
 97    i: Optional[int] = None,
 98    skip_load: bool = False,
 99    **kwargs,
100) -> Optional[instance_segmentation.AMGBase]:
101    """Compute and cache or load the state for the automatic mask generator.
102
103    Args:
104        predictor: The Segment Anything predictor.
105        decoder: The instance segmentation decoder.
106        raw: The image data.
107        image_embeddings: The image embeddings.
108        save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
109        verbose: Whether to run the computation verbose. By default, set to 'True'.
110        i: The index for which to cache the state.
111        skip_load: Skip loading the state if it is precomputed. By default, set to 'False'.
112        kwargs: The keyword arguments for the amg class.
113
114    Returns:
115        The instance segmentation class with the cached state.
116    """
117    is_tiled = image_embeddings["input_size"] is None
118    amg = instance_segmentation.get_instance_segmentation_generator(
119        predictor, is_tiled=is_tiled, decoder=decoder, **kwargs
120    )
121
122    # If i is given we compute the state for a given slice/frame.
123    # And we have to save the state for slices/frames separately.
124    save_path = os.path.join(save_path, "is_state.h5")
125    save_key = "state" if i is None else f"state-{i}"
126
127    with h5py.File(save_path, "a") as f:
128        if save_key in f:
129            if skip_load:  # Skip loading to speed this up for cases where we don't need the return val.
130                return
131
132            if verbose:
133                print("Load instance segmentation state from", save_path, ":", save_key)
134            g = f[save_key]
135            state = {
136                "foreground": g["foreground"][:],
137                "boundary_distances": g["boundary_distances"][:],
138                "center_distances": g["center_distances"][:],
139            }
140            amg.set_state(state)
141            return amg
142
143    if verbose:
144        print("Precomputing the state for instance segmentation.")
145
146    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
147    state = amg.get_state()
148
149    with h5py.File(save_path, "a") as f:
150        g = f.create_group(save_key)
151        g.create_dataset("foreground", data=state["foreground"], compression="gzip")
152        g.create_dataset("boundary_distances", data=state["boundary_distances"], compression="gzip")
153        g.create_dataset("center_distances", data=state["center_distances"], compression="gzip")
154
155    return amg
156
157
def _precompute_state_for_file(
    predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state, decoder, verbose
):
    """Compute the embeddings, and optionally the instance segmentation state, for a single image.

    `input_path` is either an in-memory array or a path that is loaded with `util.load_image_data`
    using `key`. The embeddings are written to `output_path` with its suffix replaced by '.zarr';
    the optional instance segmentation state is cached under that same path.
    """
    image_data = input_path if isinstance(input_path, np.ndarray) else util.load_image_data(input_path, key)

    # The embeddings are always stored in a zarr container derived from the requested output path.
    embedding_path = Path(output_path).with_suffix(".zarr")
    embeddings = util.precompute_image_embeddings(
        predictor, image_data, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=verbose
    )

    if not precompute_amg_state:
        return

    # State for automatic instance segmentation: AMG without a decoder, AIS with one.
    shared_kwargs = dict(predictor=predictor, image_embeddings=embeddings, save_path=embedding_path)
    if decoder is None:
        cache_function = partial(cache_amg_state, **shared_kwargs)
    else:
        cache_function = partial(cache_is_state, decoder=decoder, **shared_kwargs)

    data_ndim = image_data.ndim if ndim is None else ndim
    if data_ndim == 2:
        cache_function(raw=image_data, verbose=verbose)
        return

    # Volumes / timeseries: one state per slice or frame along the first axis.
    n_slices = image_data.shape[0]
    progress_bar = tqdm(
        range(n_slices), total=n_slices, desc="Precompute instance segmentation state", disable=not verbose
    )
    for slice_id in progress_bar:
        cache_function(raw=image_data, i=slice_id, verbose=False)
193
194
def _precompute_state_for_files(
    predictor: SamPredictor,
    input_files: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
    output_path: Union[os.PathLike, str],
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False,
    decoder: Optional["nn.Module"] = None,
):
    """Precompute the embeddings, and optionally the instance segmentation state, for several images.

    The results for each input are written into the folder `output_path`. For file paths, the result
    is named after the input file's basename. For in-memory arrays, it is named after the array's
    position in `input_files` ('embedding_00000', 'embedding_00001', ...).
    """
    os.makedirs(output_path, exist_ok=True)
    progress_bar = tqdm(input_files, total=len(input_files), desc="Precompute state for files")
    for idx, file_path in enumerate(progress_bar):

        if isinstance(file_path, np.ndarray):
            out_path = os.path.join(output_path, f"embedding_{idx:05}.tif")
        else:
            out_path = os.path.join(output_path, os.path.basename(file_path))

        _precompute_state_for_file(
            predictor, file_path, out_path,
            key=key, ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state, decoder=decoder,
            verbose=False,
        )
222
223
def precompute_state(
    input_path: Union[os.PathLike, str],
    output_path: Union[os.PathLike, str],
    pattern: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False,
) -> None:
    """Precompute the image embeddings and other optional state for the input image(s).

    Args:
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            a container file (e.g. hdf5 or zarr) or a folder with image files.
            In case of a container file the argument `key` must be given. In case of a folder,
            `pattern` can be given to select the files for which the embeddings are computed
            (or `key`, to load the files as a single 3d volume).
        output_path: The output path where the embeddings and other state will be saved.
        pattern: Glob pattern to select files in a folder. The embeddings will be computed
            for each of these files. To select all files in a folder pass "*".
        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
        checkpoint_path: Path to a checkpoint for a custom model.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default, computes it from the input data.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        precompute_amg_state: Whether to precompute the state for automatic instance segmentation
            in addition to the image embeddings.
    """
    predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

    # If the model state contains decoder weights, the decoder is passed on and used for the
    # instance segmentation state instead of the automatic mask generator.
    if "decoder_state" in state:
        decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])
    else:
        decoder = None

    # Check if we precompute the state for a single file or for a folder with image files.
    if pattern is None:
        _precompute_state_for_file(
            predictor, input_path, output_path, key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
            decoder=decoder, verbose=True,
        )
    else:
        # glob returns files in arbitrary (filesystem) order; sort for deterministic processing.
        input_files = sorted(glob(os.path.join(input_path, pattern)))
        _precompute_state_for_files(
            predictor, input_files, output_path, key=key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
            decoder=decoder,
        )
279
280
def main():
    """@private"""
    # Command line entry point: parses the arguments and forwards them to `precompute_state`.
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-e", "--embedding_path", required=True, help="The path where the embeddings will be saved."
    )
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    # nargs="+" means tile_shape / halo arrive as lists of ints (or None if not given).
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. "
        "Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "-p", "--precompute_amg_state", action="store_true",
        help="Whether to precompute the state for automatic instance segmentation."
    )

    args = parser.parse_args()
    precompute_state(
        args.input_path, args.embedding_path,
        model_type=args.model_type, checkpoint_path=args.checkpoint,
        pattern=args.pattern, key=args.key,
        tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
        precompute_amg_state=args.precompute_amg_state,
    )


if __name__ == "__main__":
    main()
def cache_amg_state( predictor: segment_anything.predictor.SamPredictor, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, **kwargs) -> micro_sam.instance_segmentation.AMGBase:
def cache_amg_state(
    predictor: SamPredictor,
    raw: np.ndarray,
    image_embeddings: util.ImageEmbeddings,
    save_path: Union[str, os.PathLike],
    verbose: bool = True,
    i: Optional[int] = None,
    **kwargs,
) -> instance_segmentation.AMGBase:
    """Compute and cache or load the state for the automatic mask generator.

    Args:
        predictor: The Segment Anything predictor.
        raw: The image data.
        image_embeddings: The image embeddings.
        save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle',
            or in 'save_path/amg_state/state-{i}.pkl' if `i` is given.
        verbose: Whether to run the computation verbose. By default, set to 'True'.
        i: The index of the slice / frame for which to cache the state. If given, `raw[i]` is used.
        kwargs: The keyword arguments for the amg class.

    Returns:
        The automatic mask generator class with the cached state.
    """
    is_tiled = image_embeddings["input_size"] is None
    amg = instance_segmentation.get_instance_segmentation_generator(predictor, is_tiled=is_tiled, **kwargs)

    # If i is given we compute the state for a given slice/frame.
    # And we have to save the state for slices/frames separately.
    if i is None:
        save_path_amg = os.path.join(save_path, "amg_state.pickle")
    else:
        os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
        save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")

    if os.path.exists(save_path_amg):
        if verbose:
            print("Load the AMG state from", save_path_amg)
        # NOTE: unpickling can execute arbitrary code; only load state files from a trusted location.
        with open(save_path_amg, "rb") as f:
            amg_state = pickle.load(f)
        amg.set_state(amg_state)
        return amg

    if verbose:
        print("Precomputing the state for instance segmentation.")

    amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i)
    amg_state = amg.get_state()

    # Put all state onto the cpu so that the state can be deserialized without a gpu.
    # The dicts are updated in place; only existing keys are reassigned, so iterating is safe.
    for mask_data in amg_state["crop_list"]:
        for k, v in mask_data.items():
            if torch.is_tensor(v):
                mask_data[k] = v.cpu()

    # Write to a temporary file and move it into place afterwards. Otherwise an interrupted run
    # would leave a truncated file that the existence check above would try to load on every call.
    tmp_path = f"{save_path_amg}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(amg_state, f)
    os.replace(tmp_path, save_path_amg)

    return amg

Compute and cache or load the state for the automatic mask generator.

Arguments:
  • predictor: The Segment Anything predictor.
  • raw: The image data.
  • image_embeddings: The image embeddings.
  • save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle', or in 'save_path/amg_state/state-{i}.pkl' if i is given.
  • verbose: Whether to run the computation verbose. By default, set to 'True'.
  • i: The index for which to cache the state.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator class with the cached state.

def cache_is_state( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, skip_load: bool = False, **kwargs) -> Optional[micro_sam.instance_segmentation.AMGBase]:
 91def cache_is_state(
 92    predictor: SamPredictor,
 93    decoder: torch.nn.Module,
 94    raw: np.ndarray,
 95    image_embeddings: util.ImageEmbeddings,
 96    save_path: Union[str, os.PathLike],
 97    verbose: bool = True,
 98    i: Optional[int] = None,
 99    skip_load: bool = False,
100    **kwargs,
101) -> Optional[instance_segmentation.AMGBase]:
102    """Compute and cache or load the state for the automatic mask generator.
103
104    Args:
105        predictor: The Segment Anything predictor.
106        decoder: The instance segmentation decoder.
107        raw: The image data.
108        image_embeddings: The image embeddings.
109        save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
110        verbose: Whether to run the computation verbose. By default, set to 'True'.
111        i: The index for which to cache the state.
112        skip_load: Skip loading the state if it is precomputed. By default, set to 'False'.
113        kwargs: The keyword arguments for the amg class.
114
115    Returns:
116        The instance segmentation class with the cached state.
117    """
118    is_tiled = image_embeddings["input_size"] is None
119    amg = instance_segmentation.get_instance_segmentation_generator(
120        predictor, is_tiled=is_tiled, decoder=decoder, **kwargs
121    )
122
123    # If i is given we compute the state for a given slice/frame.
124    # And we have to save the state for slices/frames separately.
125    save_path = os.path.join(save_path, "is_state.h5")
126    save_key = "state" if i is None else f"state-{i}"
127
128    with h5py.File(save_path, "a") as f:
129        if save_key in f:
130            if skip_load:  # Skip loading to speed this up for cases where we don't need the return val.
131                return
132
133            if verbose:
134                print("Load instance segmentation state from", save_path, ":", save_key)
135            g = f[save_key]
136            state = {
137                "foreground": g["foreground"][:],
138                "boundary_distances": g["boundary_distances"][:],
139                "center_distances": g["center_distances"][:],
140            }
141            amg.set_state(state)
142            return amg
143
144    if verbose:
145        print("Precomputing the state for instance segmentation.")
146
147    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
148    state = amg.get_state()
149
150    with h5py.File(save_path, "a") as f:
151        g = f.create_group(save_key)
152        g.create_dataset("foreground", data=state["foreground"], compression="gzip")
153        g.create_dataset("boundary_distances", data=state["boundary_distances"], compression="gzip")
154        g.create_dataset("center_distances", data=state["center_distances"], compression="gzip")
155
156    return amg

Compute and cache or load the state for instance segmentation with a decoder.

Arguments:
  • predictor: The Segment Anything predictor.
  • decoder: The instance segmentation decoder.
  • raw: The image data.
  • image_embeddings: The image embeddings.
  • save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
  • verbose: Whether to run the computation verbose. By default, set to 'True'.
  • i: The index for which to cache the state.
  • skip_load: Skip loading the state if it is precomputed. By default, set to 'False'.
  • kwargs: The keyword arguments for the amg class.
Returns:

The instance segmentation class with the cached state.

def precompute_state( input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], pattern: Optional[str] = None, model_type: str = 'vit_b_lm', checkpoint_path: Union[os.PathLike, str, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, precompute_amg_state: bool = False) -> None:
def precompute_state(
    input_path: Union[os.PathLike, str],
    output_path: Union[os.PathLike, str],
    pattern: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False,
) -> None:
    """Precompute the image embeddings and other optional state for the input image(s).

    Args:
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            a container file (e.g. hdf5 or zarr) or a folder with image files.
            In case of a container file the argument `key` must be given. In case of a folder,
            `pattern` can be given to select the files for which the embeddings are computed
            (or `key`, to load the files as a single 3d volume).
        output_path: The output path where the embeddings and other state will be saved.
        pattern: Glob pattern to select files in a folder. The embeddings will be computed
            for each of these files. To select all files in a folder pass "*".
        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
        checkpoint_path: Path to a checkpoint for a custom model.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default, computes it from the input data.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        precompute_amg_state: Whether to precompute the state for automatic instance segmentation
            in addition to the image embeddings.
    """
    predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

    # If the model state contains decoder weights, the decoder is passed on and used for the
    # instance segmentation state instead of the automatic mask generator.
    if "decoder_state" in state:
        decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])
    else:
        decoder = None

    # Check if we precompute the state for a single file or for a folder with image files.
    if pattern is None:
        _precompute_state_for_file(
            predictor, input_path, output_path, key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
            decoder=decoder, verbose=True,
        )
    else:
        # glob returns files in arbitrary (filesystem) order; sort for deterministic processing.
        input_files = sorted(glob(os.path.join(input_path, pattern)))
        _precompute_state_for_files(
            predictor, input_files, output_path, key=key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
            decoder=decoder,
        )

Precompute the image embeddings and other optional state for the input image(s).

Arguments:
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), a container file (e.g. hdf5 or zarr) or a folder with image files. In case of a container file the argument key must be given. In case of a folder, pattern can be given to select the files for which the embeddings are computed (or key, to load the files as a single 3d volume).
  • output_path: The output path where the embeddings and other state will be saved.
  • pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*".
  • model_type: The Segment Anything model to use. Will use the vit_b_lm model by default.
  • checkpoint_path: Path to a checkpoint for a custom model.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data. By default, computes it from the input data.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings.