synapse_net.tools.cli

import argparse
import os
from functools import partial

import torch
import torch_em
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.scalable_segmentation import scalable_segmentation
from ..inference.util import inference_helper, parse_tiling


def imod_point_cli():
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given, we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to overwrite already present export results."
    )
    args = parser.parse_args()

    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given, we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to overwrite already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


# TODO: handle kwargs
def segmentation_cli():
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    model_names = list(_get_model_registry().urls.keys())
    model_names = ", ".join(model_names)
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the file structure matches input_path."
    )
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to overwrite already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given, it will override the default behavior."
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        if checkpoint_path.endswith("best.pt"):
            checkpoint_path = os.path.split(checkpoint_path)[0]

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            model = torch.load(checkpoint_path, weights_only=False)
        assert model is not None, f"The model from {args.checkpoint} could not be loaded."

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True

    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )
def imod_point_cli():

Convert a vesicle segmentation to an IMOD point model, corresponding to a sphere for each vesicle in the segmentation.
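The console-script name for this entry point is defined in the package configuration and is not shown in this module, so the minimal sketch below calls the function directly and sets sys.argv by hand; all paths and values are hypothetical.

import sys
from synapse_net.tools.cli import imod_point_cli

# Hypothetical invocation: parse_args() reads sys.argv[1:], so we set it explicitly.
sys.argv = [
    "imod_point_cli",
    "--input_path", "/data/tomograms",           # directory with .mrc tomograms
    "--segmentation_path", "/data/vesicle_seg",  # matching vesicle segmentations stored as tif
    "--output_path", "/data/imod_exports",       # output directory for the exported point models
    "--min_radius", "15",                        # exclude vesicles with a radius below 15 nm
    "--radius_factor", "1.1",                    # scale the exported sphere radii by 10%
]
imod_point_cli()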
def imod_object_cli():

Convert segmented objects to closed contour IMOD models.
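A similar sketch for the closed-contour export, here assuming the segmentations are stored in hdf5 files under an internal dataset key; the key name and paths are made up for illustration.

import sys
from synapse_net.tools.cli import imod_object_cli

# Hypothetical invocation with hdf5 segmentations and overwriting of previous exports.
sys.argv = [
    "imod_object_cli",
    "--input_path", "/data/tomograms",
    "--segmentation_path", "/data/mito_seg",
    "--output_path", "/data/imod_exports",
    "--segmentation_key", "labels/mitochondria",  # hdf5 dataset holding the segmentation
    "--force",                                    # overwrite already present export results
]
imod_object_cli()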
def segmentation_cli():

Run segmentation for tomogram data with one of the available models or a custom checkpoint.
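A sketch of a typical call. The model name must be one of the keys from _get_model_registry().urls, which is not reproduced here, so the name used below is an assumption; tile shape and halo follow the ZYX convention described by the arguments above.

import sys
from synapse_net.tools.cli import segmentation_cli

# Hypothetical invocation; check --help for the list of valid model names.
sys.argv = [
    "segmentation_cli",
    "--input_path", "/data/tomograms",       # directory with .mrc files
    "--output_path", "/data/segmentations",  # segmentations are written here as tif
    "--model", "vesicles_3d",                # assumed model name from the registry
    "--tile_shape", "64", "256", "256",      # ZYX tile shape; lower it if GPU memory is insufficient
    "--halo", "16", "64", "64",              # ZYX halo to reduce tile-boundary artifacts
    "--verbose",
]
segmentation_cli()

Since --scale is not passed here, the inputs are rescaled to the model's average training resolution based on the voxel size read from the mrc files, as implemented in the source above.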