micro_sam.precompute_state
Precompute image embeddings and automatic mask generator state for image data.
1"""Precompute image embeddings and automatic mask generator state for image data. 2""" 3 4import os 5import pickle 6from glob import glob 7from pathlib import Path 8from functools import partial 9from typing import Optional, Tuple, Union, List 10 11import h5py 12import numpy as np 13 14import torch 15import torch.nn as nn 16 17from segment_anything.predictor import SamPredictor 18 19try: 20 from napari.utils import progress as tqdm 21except ImportError: 22 from tqdm import tqdm 23 24from . import instance_segmentation, util 25 26 27def cache_amg_state( 28 predictor: SamPredictor, 29 raw: np.ndarray, 30 image_embeddings: util.ImageEmbeddings, 31 save_path: Union[str, os.PathLike], 32 verbose: bool = True, 33 i: Optional[int] = None, 34 **kwargs, 35) -> instance_segmentation.AMGBase: 36 """Compute and cache or load the state for the automatic mask generator. 37 38 Args: 39 predictor: The segment anything predictor. 40 raw: The image data. 41 image_embeddings: The image embeddings. 42 save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'. 43 verbose: Whether to run the computation verbose. 44 i: The index for which to cache the state. 45 kwargs: The keyword arguments for the amg class. 46 47 Returns: 48 The automatic mask generator class with the cached state. 49 """ 50 is_tiled = image_embeddings["input_size"] is None 51 amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs) 52 53 # If i is given we compute the state for a given slice/frame. 54 # And we have to save the state for slices/frames separately. 55 if i is None: 56 save_path_amg = os.path.join(save_path, "amg_state.pickle") 57 else: 58 os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True) 59 save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl") 60 61 if os.path.exists(save_path_amg): 62 if verbose: 63 print("Load the AMG state from", save_path_amg) 64 with open(save_path_amg, "rb") as f: 65 amg_state = pickle.load(f) 66 amg.set_state(amg_state) 67 return amg 68 69 if verbose: 70 print("Precomputing the state for instance segmentation.") 71 72 amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i) 73 amg_state = amg.get_state() 74 75 # put all state onto the cpu so that the state can be deserialized without a gpu 76 new_crop_list = [] 77 for mask_data in amg_state["crop_list"]: 78 for k, v in mask_data.items(): 79 if torch.is_tensor(v): 80 mask_data[k] = v.cpu() 81 new_crop_list.append(mask_data) 82 amg_state["crop_list"] = new_crop_list 83 84 with open(save_path_amg, "wb") as f: 85 pickle.dump(amg_state, f) 86 87 return amg 88 89 90def cache_is_state( 91 predictor: SamPredictor, 92 decoder: torch.nn.Module, 93 raw: np.ndarray, 94 image_embeddings: util.ImageEmbeddings, 95 save_path: Union[str, os.PathLike], 96 verbose: bool = True, 97 i: Optional[int] = None, 98 skip_load: bool = False, 99 **kwargs, 100) -> Optional[instance_segmentation.AMGBase]: 101 """Compute and cache or load the state for the automatic mask generator. 102 103 Args: 104 predictor: The segment anything predictor. 105 decoder: The instance segmentation decoder. 106 raw: The image data. 107 image_embeddings: The image embeddings. 108 save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'. 109 verbose: Whether to run the computation verbose. 110 i: The index for which to cache the state. 111 skip_load: Skip loading the state if it is precomputed. 
112 kwargs: The keyword arguments for the amg class. 113 114 Returns: 115 The instance segmentation class with the cached state. 116 """ 117 is_tiled = image_embeddings["input_size"] is None 118 amg = instance_segmentation.get_amg(predictor, is_tiled, decoder=decoder, **kwargs) 119 120 # If i is given we compute the state for a given slice/frame. 121 # And we have to save the state for slices/frames separately. 122 save_path = os.path.join(save_path, "is_state.h5") 123 save_key = "state" if i is None else f"state-{i}" 124 125 with h5py.File(save_path, "a") as f: 126 if save_key in f: 127 if skip_load: # Skip loading to speed this up for cases where we don't need the return val. 128 return 129 130 if verbose: 131 print("Load instance segmentation state from", save_path, ":", save_key) 132 g = f[save_key] 133 state = { 134 "foreground": g["foreground"][:], 135 "boundary_distances": g["boundary_distances"][:], 136 "center_distances": g["center_distances"][:], 137 } 138 amg.set_state(state) 139 return amg 140 141 if verbose: 142 print("Precomputing the state for instance segmentation.") 143 144 amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i) 145 state = amg.get_state() 146 147 with h5py.File(save_path, "a") as f: 148 g = f.create_group(save_key) 149 g.create_dataset("foreground", data=state["foreground"], compression="gzip") 150 g.create_dataset("boundary_distances", data=state["boundary_distances"], compression="gzip") 151 g.create_dataset("center_distances", data=state["center_distances"], compression="gzip") 152 153 return amg 154 155 156def _precompute_state_for_file( 157 predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state, decoder, verbose 158): 159 if isinstance(input_path, np.ndarray): 160 image_data = input_path 161 else: 162 image_data = util.load_image_data(input_path, key) 163 164 # Precompute the image embeddings. 165 output_path = Path(output_path).with_suffix(".zarr") 166 embeddings = util.precompute_image_embeddings( 167 predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=verbose 168 ) 169 170 # Precompute the state for automatic instance segmnetaiton (AMG or AIS). 
171 if precompute_amg_state: 172 if decoder is None: 173 cache_function = partial( 174 cache_amg_state, predictor=predictor, image_embeddings=embeddings, save_path=output_path 175 ) 176 else: 177 cache_function = partial( 178 cache_is_state, predictor=predictor, decoder=decoder, 179 image_embeddings=embeddings, save_path=output_path 180 ) 181 182 if ndim is None: 183 ndim = image_data.ndim 184 185 if ndim == 2: 186 cache_function(raw=image_data, verbose=verbose) 187 else: 188 n = image_data.shape[0] 189 for i in tqdm(range(n), total=n, desc="Precompute instance segmentation state", disable=not verbose): 190 cache_function(raw=image_data, i=i, verbose=False) 191 192 193def _precompute_state_for_files( 194 predictor: SamPredictor, 195 input_files: Union[List[Union[os.PathLike, str]], List[np.ndarray]], 196 output_path: Union[os.PathLike, str], 197 key: Optional[str] = None, 198 ndim: Optional[int] = None, 199 tile_shape: Optional[Tuple[int, int]] = None, 200 halo: Optional[Tuple[int, int]] = None, 201 precompute_amg_state: bool = False, 202 decoder: Optional["nn.Module"] = None, 203): 204 os.makedirs(output_path, exist_ok=True) 205 for i, file_path in enumerate(tqdm(input_files, total=len(input_files), desc="Precompute state for files")): 206 207 if isinstance(file_path, np.ndarray): 208 out_path = os.path.join(output_path, f"embedding_{i:05}.tif") 209 else: 210 out_path = os.path.join(output_path, os.path.basename(file_path)) 211 212 _precompute_state_for_file( 213 predictor, file_path, out_path, 214 key=key, ndim=ndim, tile_shape=tile_shape, halo=halo, 215 precompute_amg_state=precompute_amg_state, decoder=decoder, 216 verbose=False, 217 ) 218 219 220def precompute_state( 221 input_path: Union[os.PathLike, str], 222 output_path: Union[os.PathLike, str], 223 pattern: Optional[str] = None, 224 model_type: str = util._DEFAULT_MODEL, 225 checkpoint_path: Optional[Union[os.PathLike, str]] = None, 226 key: Optional[str] = None, 227 ndim: Optional[int] = None, 228 tile_shape: Optional[Tuple[int, int]] = None, 229 halo: Optional[Tuple[int, int]] = None, 230 precompute_amg_state: bool = False, 231) -> None: 232 """Precompute the image embeddings and other optional state for the input image(s). 233 234 Args: 235 input_path: The input image file(s). Can either be a single image file (e.g. tif or png), 236 a container file (e.g. hdf5 or zarr) or a folder with images files. 237 In case of a container file the argument `key` must be given. In case of a folder 238 it can be given to provide a glob pattern to subselect files from the folder. 239 output_path: The output path where the embeddings and other state will be saved. 240 pattern: Glob pattern to select files in a folder. The embeddings will be computed 241 for each of these files. To select all files in a folder pass "*". 242 model_type: The SegmentAnything model to use. Will use the standard vit_l model by default. 243 checkpoint_path: Path to a checkpoint for a custom model. 244 key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) 245 or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case. 246 ndim: The dimensionality of the data. 247 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 248 halo: Overlap of the tiles for tiled prediction. 249 precompute_amg_state: Whether to precompute the state for automatic instance segmentation 250 in addition to the image embeddings. 
251 """ 252 predictor, state = util.get_sam_model( 253 model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, 254 ) 255 if "decoder_state" in state: 256 decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"]) 257 else: 258 decoder = None 259 260 # Check if we precompute the state for a single file or for a folder with image files. 261 if pattern is None: 262 _precompute_state_for_file( 263 predictor, input_path, output_path, key, 264 ndim=ndim, tile_shape=tile_shape, halo=halo, 265 precompute_amg_state=precompute_amg_state, 266 decoder=decoder, verbose=True, 267 ) 268 else: 269 input_files = glob(os.path.join(input_path, pattern)) 270 _precompute_state_for_files( 271 predictor, input_files, output_path, key=key, 272 ndim=ndim, tile_shape=tile_shape, halo=halo, 273 precompute_amg_state=precompute_amg_state, 274 decoder=decoder, 275 ) 276 277 278def main(): 279 """@private""" 280 import argparse 281 282 available_models = list(util.get_model_names()) 283 available_models = ", ".join(available_models) 284 285 parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") 286 parser.add_argument( 287 "-i", "--input_path", required=True, 288 help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) " 289 "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter." 290 ) 291 parser.add_argument( 292 "-e", "--embedding_path", required=True, help="The path where the embeddings will be saved." 293 ) 294 295 parser.add_argument( 296 "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'." 297 ) 298 parser.add_argument( 299 "-k", "--key", 300 help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, " 301 "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'." 302 ) 303 304 parser.add_argument( 305 "-m", "--model_type", default=util._DEFAULT_MODEL, 306 help=f"The segment anything model that will be used, one of {available_models}." 307 ) 308 parser.add_argument( 309 "-c", "--checkpoint", default=None, 310 help="Checkpoint from which the SAM model will be loaded loaded." 311 ) 312 parser.add_argument( 313 "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None 314 ) 315 parser.add_argument( 316 "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None 317 ) 318 parser.add_argument( 319 "-n", "--ndim", type=int, default=None, 320 help="The number of spatial dimensions in the data. " 321 "Please specify this if your data has a channel dimension." 322 ) 323 parser.add_argument( 324 "-p", "--precompute_amg_state", action="store_true", 325 help="Whether to precompute the state for automatic instance segmentation." 326 ) 327 328 args = parser.parse_args() 329 precompute_state( 330 args.input_path, args.embedding_path, 331 model_type=args.model_type, checkpoint_path=args.checkpoint, 332 pattern=args.pattern, key=args.key, 333 tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim, 334 precompute_amg_state=args.precompute_amg_state, 335 ) 336 337 338if __name__ == "__main__": 339 main()
def cache_amg_state(predictor: segment_anything.predictor.SamPredictor, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, **kwargs) -> micro_sam.instance_segmentation.AMGBase:
Compute and cache or load the state for the automatic mask generator.
Arguments:
- predictor: The segment anything predictor.
- raw: The image data.
- image_embeddings: The image embeddings.
- save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
- verbose: Whether to run the computation verbosely.
- i: The index for which to cache the state.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator class with the cached state.
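A minimal usage sketch, not part of the module: it assumes a 2d image stored as "example_image.tif", the "vit_b" model and the embedding directory "embeddings.zarr"; all of these names are placeholders.

import imageio.v3 as imageio
from micro_sam import util
from micro_sam.precompute_state import cache_amg_state

image = imageio.imread("example_image.tif")  # placeholder input image

predictor = util.get_sam_model(model_type="vit_b")
embeddings = util.precompute_image_embeddings(predictor, image, "embeddings.zarr", ndim=2)

# Computes the AMG state and caches it in 'embeddings.zarr/amg_state.pickle'.
# A second call with the same save_path loads the cached state instead of recomputing it.
amg = cache_amg_state(predictor, image, embeddings, save_path="embeddings.zarr")

# The returned mask generator is already initialized, so masks can be generated
# without recomputing the embedding-dependent state.
masks = amg.generate()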
def cache_is_state(predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module, raw: numpy.ndarray, image_embeddings: Dict[str, Any], save_path: Union[str, os.PathLike], verbose: bool = True, i: Optional[int] = None, skip_load: bool = False, **kwargs) -> Optional[micro_sam.instance_segmentation.AMGBase]:
Compute and cache or load the state for automatic instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The instance segmentation decoder.
- raw: The image data.
- image_embeddings: The image embeddings.
- save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
- verbose: Whether to run the computation verbosely.
- i: The index for which to cache the state.
- skip_load: Skip loading the state if it is precomputed.
- kwargs: The keyword arguments for the amg class.
Returns:
The instance segmentation class with the cached state.
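A minimal usage sketch, not part of the module: it requires a model that ships a decoder state (the model name "vit_b_lm" and the file paths are placeholders) and mirrors how precompute_state constructs the decoder.

import imageio.v3 as imageio
from micro_sam import util, instance_segmentation
from micro_sam.precompute_state import cache_is_state

image = imageio.imread("example_image.tif")  # placeholder input image

predictor, state = util.get_sam_model(model_type="vit_b_lm", return_state=True)
assert "decoder_state" in state, "cache_is_state needs a model with an instance segmentation decoder"
decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])

embeddings = util.precompute_image_embeddings(predictor, image, "embeddings.zarr", ndim=2)

# Computes the AIS state and caches it in 'embeddings.zarr/is_state.h5'.
# A second call loads the cached state, or returns None directly when skip_load=True.
ais = cache_is_state(predictor, decoder, image, embeddings, save_path="embeddings.zarr")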
def precompute_state(input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], pattern: Optional[str] = None, model_type: str = 'vit_l', checkpoint_path: Union[str, os.PathLike, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, precompute_amg_state: bool = False) -> None:
Precompute the image embeddings and other optional state for the input image(s).
Arguments:
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), a container file (e.g. hdf5 or zarr) or a folder with image files. In case of a container file the argument 'key' must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder.
- output_path: The output path where the embeddings and other state will be saved.
- pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*".
- model_type: The Segment Anything model to use. Will use the standard vit_l model by default.
- checkpoint_path: Path to a checkpoint for a custom model.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings.
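A usage sketch, not part of the module; the paths, model type and tile settings below are placeholders.

from micro_sam.precompute_state import precompute_state

# Precompute embeddings and the instance segmentation state for all tif files in a folder,
# using tiled prediction for large images.
precompute_state(
    "/path/to/images", "/path/to/embeddings",
    pattern="*.tif", model_type="vit_b",
    tile_shape=(1024, 1024), halo=(256, 256),
    precompute_amg_state=True,
)

The same can be run from the command line via the module's main() entry point, e.g.:

python -m micro_sam.precompute_state -i /path/to/images -e /path/to/embeddings --pattern "*.tif" -m vit_b --tile_shape 1024 1024 --halo 256 256 -p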