synapse_net.tools.cli
import argparse
import os
from functools import partial

import torch
import torch_em
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.scalable_segmentation import scalable_segmentation
from ..inference.util import inference_helper, parse_tiling
from .pool_visualization import _visualize_vesicle_pools


def imod_point_cli():
    """Command line entry point for exporting a vesicle segmentation as an IMOD point model.

    Parses the command line arguments and calls `export_helper` with
    `write_segmentation_to_imod_as_points` as the export function. The point export is
    parameterized by `--min_radius` and `--radius_factor`.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. "
        "Objects that are smaller than this radius will be excluded from the export."
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the point-export options so that export_helper can call the function per file.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    """Command line entry point for exporting segmented objects as closed contour IMOD models.

    Parses the command line arguments and calls `export_helper` with
    `write_segmentation_to_imod` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def pool_visualization_cli():
    """Command line entry point for visualizing vesicle pools together with the tomogram.

    Parses the command line arguments and forwards them to `_visualize_vesicle_pools`.
    """
    parser = argparse.ArgumentParser(
        description="Load tomogram data, vesicle pools and additional segmentations for visualization."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )


# TODO: handle kwargs
def segmentation_cli():
    """Command line entry point for running a segmentation model on tomogram data.

    The model is either taken from the model registry via `--model`, or loaded from
    `--checkpoint`. A torch_em checkpoint directory (or the `best.pt` file inside it) is
    loaded with `torch_em.util.load_model`; any other path is loaded with `torch.load`.
    The `--model` name is still used to decide whether the model is 2D and, unless `--scale`
    is given, to look up the training resolution. In the default (non-scalable) mode it is
    also passed on as the model type. Processing of the inputs is delegated to `inference_helper`.

    Raises:
        ValueError: If `--scalable` is requested for a model that does not support it.
        RuntimeError: If the model could not be loaded.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): args.input_key is parsed but never forwarded to inference_helper below;
    # confirm whether it should be used.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior."
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    parser.add_argument(
        "--extra_input_path", default=None, help="Filepath to extra inputs, needed for cristae segmentation."
    )
    parser.add_argument(
        "--extra_input_ext", default=".tif", help="File extension for the extra inputs, default is tif."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # A torch_em 'best.pt' file is mapped to its checkpoint directory so it is loaded via torch_em below.
        # Match the exact file name (not e.g. 'model_best.pt'), and fall back to the current directory
        # when only 'best.pt' was given, since its dirname is then the empty string.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # weights_only=False unpickles the full model object, which can execute arbitrary code.
            # Only pass checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)

    # Explicit check instead of an assert, which is stripped when Python runs with -O.
    if model is None:
        model_source = args.model if args.checkpoint is None else args.checkpoint
        raise RuntimeError(f"The model from {model_source} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True

    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output, extra_input_path=args.extra_input_path,
        extra_input_ext=args.extra_input_ext
    )
def imod_point_cli():
    """Command line entry point for exporting a vesicle segmentation as an IMOD point model.

    Parses the command line arguments and calls `export_helper` with
    `write_segmentation_to_imod_as_points` as the export function. The point export is
    parameterized by `--min_radius` and `--radius_factor`.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. "
        "Objects that are smaller than this radius will be excluded from the export."
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the point-export options so that export_helper can call the function per file.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command line entry point for exporting segmented objects as closed contour IMOD models.

    Parses the command line arguments and calls `export_helper` with
    `write_segmentation_to_imod` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def pool_visualization_cli():
    """Command line entry point for visualizing vesicle pools together with the tomogram.

    Parses the command line arguments and forwards them to `_visualize_vesicle_pools`.
    """
    parser = argparse.ArgumentParser(
        description="Load tomogram data, vesicle pools and additional segmentations for visualization."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )
def segmentation_cli():
    """Command line entry point for running a segmentation model on tomogram data.

    The model is either taken from the model registry via `--model`, or loaded from
    `--checkpoint`. A torch_em checkpoint directory (or the `best.pt` file inside it) is
    loaded with `torch_em.util.load_model`; any other path is loaded with `torch.load`.
    The `--model` name is still used to decide whether the model is 2D and, unless `--scale`
    is given, to look up the training resolution. In the default (non-scalable) mode it is
    also passed on as the model type. Processing of the inputs is delegated to `inference_helper`.

    Raises:
        ValueError: If `--scalable` is requested for a model that does not support it.
        RuntimeError: If the model could not be loaded.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): args.input_key is parsed but never forwarded to inference_helper below;
    # confirm whether it should be used.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior."
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    parser.add_argument(
        "--extra_input_path", default=None, help="Filepath to extra inputs, needed for cristae segmentation."
    )
    parser.add_argument(
        "--extra_input_ext", default=".tif", help="File extension for the extra inputs, default is tif."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # A torch_em 'best.pt' file is mapped to its checkpoint directory so it is loaded via torch_em below.
        # Match the exact file name (not e.g. 'model_best.pt'), and fall back to the current directory
        # when only 'best.pt' was given, since its dirname is then the empty string.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # weights_only=False unpickles the full model object, which can execute arbitrary code.
            # Only pass checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)

    # Explicit check instead of an assert, which is stripped when Python runs with -O.
    if model is None:
        model_source = args.model if args.checkpoint is None else args.checkpoint
        raise RuntimeError(f"The model from {model_source} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True

    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output, extra_input_path=args.extra_input_path,
        extra_input_ext=args.extra_input_ext
    )