micro_sam.precompute_state
Precompute image embeddings and automatic mask generator state for image data.
1"""Precompute image embeddings and automatic mask generator state for image data. 2""" 3 4import os 5import pickle 6from glob import glob 7from pathlib import Path 8from functools import partial 9from typing import Optional, Tuple, Union, List 10 11import h5py 12import numpy as np 13 14import torch 15import torch.nn as nn 16 17from segment_anything.predictor import SamPredictor 18 19try: 20 from napari.utils import progress as tqdm 21except ImportError: 22 from tqdm import tqdm 23 24from . import instance_segmentation, util 25 26 27def cache_amg_state( 28 predictor: SamPredictor, 29 raw: np.ndarray, 30 image_embeddings: util.ImageEmbeddings, 31 save_path: Union[str, os.PathLike], 32 verbose: bool = True, 33 i: Optional[int] = None, 34 **kwargs, 35) -> instance_segmentation.AMGBase: 36 """Compute and cache or load the state for the automatic mask generator. 37 38 Args: 39 predictor: The Segment Anything predictor. 40 raw: The image data. 41 image_embeddings: The image embeddings. 42 save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'. 43 verbose: Whether to run the computation verbose. By default, set to 'True'. 44 i: The index for which to cache the state. 45 kwargs: The keyword arguments for the amg class. 46 47 Returns: 48 The automatic mask generator class with the cached state. 49 """ 50 is_tiled = image_embeddings["input_size"] is None 51 amg = instance_segmentation.get_instance_segmentation_generator(predictor, is_tiled=is_tiled, **kwargs) 52 53 # If i is given we compute the state for a given slice/frame. 54 # And we have to save the state for slices/frames separately. 55 if i is None: 56 save_path_amg = os.path.join(save_path, "amg_state.pickle") 57 else: 58 os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True) 59 save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl") 60 61 if os.path.exists(save_path_amg): 62 if verbose: 63 print("Load the AMG state from", save_path_amg) 64 with open(save_path_amg, "rb") as f: 65 amg_state = pickle.load(f) 66 amg.set_state(amg_state) 67 return amg 68 69 if verbose: 70 print("Precomputing the state for instance segmentation.") 71 72 amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i) 73 amg_state = amg.get_state() 74 75 # put all state onto the cpu so that the state can be deserialized without a gpu 76 new_crop_list = [] 77 for mask_data in amg_state["crop_list"]: 78 for k, v in mask_data.items(): 79 if torch.is_tensor(v): 80 mask_data[k] = v.cpu() 81 new_crop_list.append(mask_data) 82 amg_state["crop_list"] = new_crop_list 83 84 with open(save_path_amg, "wb") as f: 85 pickle.dump(amg_state, f) 86 87 return amg 88 89 90def cache_is_state( 91 predictor: SamPredictor, 92 decoder: torch.nn.Module, 93 raw: np.ndarray, 94 image_embeddings: util.ImageEmbeddings, 95 save_path: Union[str, os.PathLike], 96 verbose: bool = True, 97 i: Optional[int] = None, 98 skip_load: bool = False, 99 **kwargs, 100) -> Optional[instance_segmentation.AMGBase]: 101 """Compute and cache or load the state for the automatic mask generator. 102 103 Args: 104 predictor: The Segment Anything predictor. 105 decoder: The instance segmentation decoder. 106 raw: The image data. 107 image_embeddings: The image embeddings. 108 save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'. 109 verbose: Whether to run the computation verbose. By default, set to 'True'. 110 i: The index for which to cache the state. 
111 skip_load: Skip loading the state if it is precomputed. By default, set to 'False'. 112 kwargs: The keyword arguments for the amg class. 113 114 Returns: 115 The instance segmentation class with the cached state. 116 """ 117 is_tiled = image_embeddings["input_size"] is None 118 amg = instance_segmentation.get_instance_segmentation_generator( 119 predictor, is_tiled=is_tiled, decoder=decoder, **kwargs 120 ) 121 122 # If i is given we compute the state for a given slice/frame. 123 # And we have to save the state for slices/frames separately. 124 save_path = os.path.join(save_path, "is_state.h5") 125 save_key = "state" if i is None else f"state-{i}" 126 127 with h5py.File(save_path, "a") as f: 128 if save_key in f: 129 if skip_load: # Skip loading to speed this up for cases where we don't need the return val. 130 return 131 132 if verbose: 133 print("Load instance segmentation state from", save_path, ":", save_key) 134 g = f[save_key] 135 state = { 136 "foreground": g["foreground"][:], 137 "boundary_distances": g["boundary_distances"][:], 138 "center_distances": g["center_distances"][:], 139 } 140 amg.set_state(state) 141 return amg 142 143 if verbose: 144 print("Precomputing the state for instance segmentation.") 145 146 amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i) 147 state = amg.get_state() 148 149 with h5py.File(save_path, "a") as f: 150 g = f.create_group(save_key) 151 g.create_dataset("foreground", data=state["foreground"], compression="gzip") 152 g.create_dataset("boundary_distances", data=state["boundary_distances"], compression="gzip") 153 g.create_dataset("center_distances", data=state["center_distances"], compression="gzip") 154 155 return amg 156 157 158def _precompute_state_for_file( 159 predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state, decoder, verbose 160): 161 if isinstance(input_path, np.ndarray): 162 image_data = input_path 163 else: 164 image_data = util.load_image_data(input_path, key) 165 166 # Precompute the image embeddings. 167 output_path = Path(output_path).with_suffix(".zarr") 168 embeddings = util.precompute_image_embeddings( 169 predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=verbose 170 ) 171 172 # Precompute the state for automatic instance segmnetaiton (AMG or AIS). 
173 if precompute_amg_state: 174 if decoder is None: 175 cache_function = partial( 176 cache_amg_state, predictor=predictor, image_embeddings=embeddings, save_path=output_path 177 ) 178 else: 179 cache_function = partial( 180 cache_is_state, predictor=predictor, decoder=decoder, 181 image_embeddings=embeddings, save_path=output_path 182 ) 183 184 if ndim is None: 185 ndim = image_data.ndim 186 187 if ndim == 2: 188 cache_function(raw=image_data, verbose=verbose) 189 else: 190 n = image_data.shape[0] 191 for i in tqdm(range(n), total=n, desc="Precompute instance segmentation state", disable=not verbose): 192 cache_function(raw=image_data, i=i, verbose=False) 193 194 195def _precompute_state_for_files( 196 predictor: SamPredictor, 197 input_files: Union[List[Union[os.PathLike, str]], List[np.ndarray]], 198 output_path: Union[os.PathLike, str], 199 key: Optional[str] = None, 200 ndim: Optional[int] = None, 201 tile_shape: Optional[Tuple[int, int]] = None, 202 halo: Optional[Tuple[int, int]] = None, 203 precompute_amg_state: bool = False, 204 decoder: Optional["nn.Module"] = None, 205): 206 os.makedirs(output_path, exist_ok=True) 207 idx = 0 208 for file_path in tqdm(input_files, total=len(input_files), desc="Precompute state for files"): 209 210 if isinstance(file_path, np.ndarray): 211 out_path = os.path.join(output_path, f"embedding_{idx:05}.tif") 212 else: 213 out_path = os.path.join(output_path, os.path.basename(file_path)) 214 215 _precompute_state_for_file( 216 predictor, file_path, out_path, 217 key=key, ndim=ndim, tile_shape=tile_shape, halo=halo, 218 precompute_amg_state=precompute_amg_state, decoder=decoder, 219 verbose=False, 220 ) 221 idx += 1 222 223 224def precompute_state( 225 input_path: Union[os.PathLike, str], 226 output_path: Union[os.PathLike, str], 227 pattern: Optional[str] = None, 228 model_type: str = util._DEFAULT_MODEL, 229 checkpoint_path: Optional[Union[os.PathLike, str]] = None, 230 key: Optional[str] = None, 231 ndim: Optional[int] = None, 232 tile_shape: Optional[Tuple[int, int]] = None, 233 halo: Optional[Tuple[int, int]] = None, 234 precompute_amg_state: bool = False, 235) -> None: 236 """Precompute the image embeddings and other optional state for the input image(s). 237 238 Args: 239 input_path: The input image file(s). Can either be a single image file (e.g. tif or png), 240 a container file (e.g. hdf5 or zarr) or a folder with images files. 241 In case of a container file the argument `key` must be given. In case of a folder 242 it can be given to provide a glob pattern to subselect files from the folder. 243 output_path: The output path where the embeddings and other state will be saved. 244 pattern: Glob pattern to select files in a folder. The embeddings will be computed 245 for each of these files. To select all files in a folder pass "*". 246 model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default. 247 checkpoint_path: Path to a checkpoint for a custom model. 248 key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) 249 or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case. 250 ndim: The dimensionality of the data. By default, computes it from the input data. 251 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 252 halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling. 
253 precompute_amg_state: Whether to precompute the state for automatic instance segmentation 254 in addition to the image embeddings. 255 """ 256 predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) 257 258 if "decoder_state" in state: 259 decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"]) 260 else: 261 decoder = None 262 263 # Check if we precompute the state for a single file or for a folder with image files. 264 if pattern is None: 265 _precompute_state_for_file( 266 predictor, input_path, output_path, key, 267 ndim=ndim, tile_shape=tile_shape, halo=halo, 268 precompute_amg_state=precompute_amg_state, 269 decoder=decoder, verbose=True, 270 ) 271 else: 272 input_files = glob(os.path.join(input_path, pattern)) 273 _precompute_state_for_files( 274 predictor, input_files, output_path, key=key, 275 ndim=ndim, tile_shape=tile_shape, halo=halo, 276 precompute_amg_state=precompute_amg_state, 277 decoder=decoder, 278 ) 279 280 281def main(): 282 """@private""" 283 import argparse 284 285 available_models = list(util.get_model_names()) 286 available_models = ", ".join(available_models) 287 288 parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") 289 parser.add_argument( 290 "-i", "--input_path", required=True, 291 help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) " 292 "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter." 293 ) 294 parser.add_argument( 295 "-e", "--embedding_path", required=True, help="The path where the embeddings will be saved." 296 ) 297 parser.add_argument( 298 "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'." 299 ) 300 parser.add_argument( 301 "-k", "--key", 302 help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, " 303 "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'." 304 ) 305 parser.add_argument( 306 "-m", "--model_type", default=util._DEFAULT_MODEL, 307 help=f"The segment anything model that will be used, one of {available_models}." 308 ) 309 parser.add_argument( 310 "-c", "--checkpoint", default=None, 311 help="Checkpoint from which the SAM model will be loaded loaded." 312 ) 313 parser.add_argument( 314 "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None 315 ) 316 parser.add_argument( 317 "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None 318 ) 319 parser.add_argument( 320 "-n", "--ndim", type=int, default=None, 321 help="The number of spatial dimensions in the data. " 322 "Please specify this if your data has a channel dimension." 323 ) 324 parser.add_argument( 325 "-p", "--precompute_amg_state", action="store_true", 326 help="Whether to precompute the state for automatic instance segmentation." 327 ) 328 329 args = parser.parse_args() 330 precompute_state( 331 args.input_path, args.embedding_path, 332 model_type=args.model_type, checkpoint_path=args.checkpoint, 333 pattern=args.pattern, key=args.key, 334 tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim, 335 precompute_amg_state=args.precompute_amg_state, 336 ) 337 338 339if __name__ == "__main__": 340 main()
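The module can also be used from the command line: main() defines an argparse interface with the flags shown above (-i/--input_path, -e/--embedding_path, --pattern, -k/--key, -m/--model_type, -c/--checkpoint, --tile_shape, --halo, -n/--ndim, -p/--precompute_amg_state), and the `if __name__ == "__main__"` guard means it can be invoked as `python -m micro_sam.precompute_state`.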
def cache_amg_state(
    predictor: segment_anything.predictor.SamPredictor,
    raw: numpy.ndarray,
    image_embeddings: Dict[str, Any],
    save_path: Union[str, os.PathLike],
    verbose: bool = True,
    i: Optional[int] = None,
    **kwargs
) -> micro_sam.instance_segmentation.AMGBase:
Compute and cache or load the state for the automatic mask generator.
Arguments:
- predictor: The Segment Anything predictor.
- raw: The image data.
- image_embeddings: The image embeddings.
- save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
- verbose: Whether to run the computation verbosely. By default, set to 'True'.
- i: The index for which to cache the state.
- kwargs: The keyword arguments for the amg class.
Returns:
The automatic mask generator class with the cached state.
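A minimal usage sketch (not part of the module; the model name, file paths, and the random example image are placeholders) that precomputes embeddings for a 2d image and then caches the AMG state next to them:

import numpy as np

from micro_sam import util
from micro_sam.precompute_state import cache_amg_state

image = np.random.randint(0, 255, size=(512, 512)).astype("uint8")  # stand-in for a real 2d image

predictor = util.get_sam_model(model_type="vit_b_lm")
embeddings = util.precompute_image_embeddings(predictor, image, "embeddings.zarr", ndim=2)

# The first call computes the state and writes 'embeddings.zarr/amg_state.pickle';
# a later call with the same save_path loads the cached state instead of recomputing it.
amg = cache_amg_state(predictor, image, embeddings, save_path="embeddings.zarr")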
def cache_is_state(
    predictor: segment_anything.predictor.SamPredictor,
    decoder: torch.nn.modules.module.Module,
    raw: numpy.ndarray,
    image_embeddings: Dict[str, Any],
    save_path: Union[str, os.PathLike],
    verbose: bool = True,
    i: Optional[int] = None,
    skip_load: bool = False,
    **kwargs
) -> Optional[micro_sam.instance_segmentation.AMGBase]:
Compute and cache or load the state for automatic instance segmentation.
Arguments:
- predictor: The Segment Anything predictor.
- decoder: The instance segmentation decoder.
- raw: The image data.
- image_embeddings: The image embeddings.
- save_path: The embedding save path. The instance segmentation state will be stored in 'save_path/is_state.h5'.
- verbose: Whether to run the computation verbosely. By default, set to 'True'.
- i: The index for which to cache the state.
- skip_load: Skip loading the state if it is precomputed. By default, set to 'False'.
- kwargs: The keyword arguments for the amg class.
Returns:
The instance segmentation class with the cached state.
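A corresponding sketch for the decoder-based variant (again with placeholder paths and a random example image); it obtains the decoder the same way precompute_state does, assuming a model such as 'vit_b_lm' whose state includes a 'decoder_state':

import numpy as np

from micro_sam import util, instance_segmentation
from micro_sam.precompute_state import cache_is_state

image = np.random.randint(0, 255, size=(512, 512)).astype("uint8")  # stand-in for a real 2d image

predictor, state = util.get_sam_model(model_type="vit_b_lm", return_state=True)
decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])

embeddings = util.precompute_image_embeddings(predictor, image, "embeddings.zarr", ndim=2)

# The first call computes the decoder predictions and writes them to 'embeddings.zarr/is_state.h5';
# a later call with the same save_path loads the cached state instead of recomputing it.
ais = cache_is_state(predictor, decoder, image, embeddings, save_path="embeddings.zarr")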
def precompute_state(
    input_path: Union[os.PathLike, str],
    output_path: Union[os.PathLike, str],
    pattern: Optional[str] = None,
    model_type: str = 'vit_b_lm',
    checkpoint_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False
) -> None:
Precompute the image embeddings and other optional state for the input image(s).
Arguments:
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), a container file (e.g. hdf5 or zarr) or a folder with image files. In case of a container file the argument `key` must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder.
- output_path: The output path where the embeddings and other state will be saved.
- pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*".
- model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
- checkpoint_path: Path to a checkpoint for a custom model.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default, computes it from the input data.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings.
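A usage sketch with placeholder paths: precompute embeddings and the instance segmentation state for all tif files in a folder, or for a single image by omitting the pattern.

from micro_sam.precompute_state import precompute_state

# One zarr with embeddings (plus the AMG/AIS state) is written per image into 'data/embeddings'.
precompute_state(
    input_path="data/images",
    output_path="data/embeddings",
    pattern="*.tif",
    model_type="vit_b_lm",
    ndim=2,
    precompute_amg_state=True,
)

# For a single image, pass the file directly and omit the pattern.
precompute_state("data/images/im0.tif", "data/embeddings/im0.zarr", ndim=2, precompute_amg_state=True)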