synapse_net.tools.cli
"""Command line interfaces for IMOD export, vesicle pool visualization and segmentation."""
import argparse
import os
from functools import partial

import torch
import torch_em
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.scalable_segmentation import scalable_segmentation
from ..inference.util import inference_helper, parse_tiling
from .pool_visualization import _visualize_vesicle_pools


def imod_point_cli():
    """Command line interface for exporting a vesicle segmentation as an IMOD point model.

    Each vesicle in the segmentation is exported as a sphere. Arguments are parsed from the
    command line; the export is delegated to ``export_helper`` with
    ``write_segmentation_to_imod_as_points`` (configured with ``--min_radius`` and
    ``--radius_factor``) as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Adjacent string literals are joined without a separator, hence the trailing space.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. "
        "Objects that are smaller than this radius will be excluded from the export."
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the user-provided export options before handing the function to export_helper.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    """Command line interface for exporting segmented objects as closed contour IMOD models.

    Arguments are parsed from the command line; the export is delegated to ``export_helper``
    with ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Adjacent string literals are joined without a separator, hence the trailing space.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def pool_visualization_cli():
    """Command line interface for visualizing vesicle pools.

    Parses the paths to the tomogram, the vesicle segmentation(s), the pool assignment table(s)
    and optional additional segmentations, and passes them to ``_visualize_vesicle_pools``.
    """
    parser = argparse.ArgumentParser(
        description="Load tomogram data, vesicle pools and additional segmentations for visualization."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )


# TODO: handle kwargs
def segmentation_cli():
    """Command line interface for running a segmentation model on tomogram data.

    The model is either taken from the model registry (``--model``) or loaded from a custom
    ``--checkpoint``. By default the inputs are rescaled to match the training resolution of the
    model, based on the voxel size of the input data; ``--scale`` overrides this with a fixed
    factor. The segmentation is run through ``inference_helper``, either with ``run_segmentation``
    or, if ``--scalable`` is given, with ``scalable_segmentation``.

    Raises:
        RuntimeError: If no model could be loaded from ``--checkpoint``.
        ValueError: If ``--scalable`` is passed for a model that does not support it.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): input_key is parsed but never used below -- TODO forward it or remove the option.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior."
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # torch_em checkpoints are loaded from their directory: if the "best.pt" file inside it was
        # passed, use the parent directory. Compare the whole file name so that e.g. "my_best.pt"
        # is not mistaken for it, and resolve the path so a bare "best.pt" maps to the cwd.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(os.path.abspath(checkpoint_path))

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # weights_only=False unpickles arbitrary objects: only load checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True

    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )
def
imod_point_cli():
def imod_point_cli():
    """Command line interface for exporting a vesicle segmentation as an IMOD point model.

    Each vesicle in the segmentation is exported as a sphere. Arguments are parsed from the
    command line; the export is delegated to ``export_helper`` with
    ``write_segmentation_to_imod_as_points`` (configured with ``--min_radius`` and
    ``--radius_factor``) as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Adjacent string literals are joined without a separator, hence the trailing space.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. "
        "Objects that are smaller than this radius will be excluded from the export."
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the user-provided export options before handing the function to export_helper.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def
imod_object_cli():
def imod_object_cli():
    """Command line interface for exporting segmented objects as closed contour IMOD models.

    Arguments are parsed from the command line; the export is delegated to ``export_helper``
    with ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Adjacent string literals are joined without a separator, hence the trailing space.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def
pool_visualization_cli():
def pool_visualization_cli():
    """Command line interface for visualizing vesicle pools.

    Parses the paths to the tomogram, the vesicle segmentation(s), the pool assignment table(s)
    and optional additional segmentations, and passes them to ``_visualize_vesicle_pools``.
    """
    parser = argparse.ArgumentParser(
        description="Load tomogram data, vesicle pools and additional segmentations for visualization."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )
def
segmentation_cli():
def segmentation_cli():
    """Command line interface for running a segmentation model on tomogram data.

    The model is either taken from the model registry (``--model``) or loaded from a custom
    ``--checkpoint``. By default the inputs are rescaled to match the training resolution of the
    model, based on the voxel size of the input data; ``--scale`` overrides this with a fixed
    factor. The segmentation is run through ``inference_helper``, either with ``run_segmentation``
    or, if ``--scalable`` is given, with ``scalable_segmentation``.

    Raises:
        RuntimeError: If no model could be loaded from ``--checkpoint``.
        ValueError: If ``--scalable`` is passed for a model that does not support it.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): input_key is parsed but never used below -- TODO forward it or remove the option.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior."
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # torch_em checkpoints are loaded from their directory: if the "best.pt" file inside it was
        # passed, use the parent directory. Compare the whole file name so that e.g. "my_best.pt"
        # is not mistaken for it, and resolve the path so a bare "best.pt" maps to the cwd.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(os.path.abspath(checkpoint_path))

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # weights_only=False unpickles arbitrary objects: only load checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True

    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )