synapse_net.tools.cli
import argparse
from functools import partial

import torch
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.util import inference_helper, parse_tiling


def imod_point_cli():
    """Command line entry point: export a vesicle segmentation as IMOD point models.

    Parses the command line arguments, binds ``--min_radius`` and ``--radius_factor``
    to ``write_segmentation_to_imod_as_points`` and runs the export through
    ``export_helper`` for the given tomogram(s) and segmentation(s).
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # The trailing space matters: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the point-export options so export_helper can call the function uniformly.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    """Command line entry point: export segmented objects as closed contour IMOD models.

    Parses the command line arguments and runs ``export_helper`` with
    ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # The trailing space matters: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


# TODO: handle kwargs
def segmentation_cli():
    """Command line entry point: run a segmentation model on tomogram data.

    The model is taken from the model registry (``--model``) unless a custom
    checkpoint is given (``--checkpoint``). If ``--scale`` is not given, the model's
    training resolution is passed to ``inference_helper`` so that the inputs are
    rescaled based on their voxel size. Otherwise, the user-provided factor is applied
    to every spatial axis.

    Raises:
        RuntimeError: If loading the checkpoint yields ``None``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # Advertise the registered model names in the --model help text.
    model_names = list(_get_model_registry().urls.keys())
    model_names = ", ".join(model_names)
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded to inference_helper below,
    # so it currently has no effect. TODO: wire it through or remove it.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        # NOTE(review): weights_only=False unpickles arbitrary Python objects, which can
        # execute code while loading. Only pass checkpoints from trusted sources.
        model = torch.load(args.checkpoint, weights_only=False)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # A "2d" substring in the model name selects YX processing; otherwise ZYX.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )
def imod_point_cli():
    """Command line entry point: export a vesicle segmentation as IMOD point models.

    Parses the command line arguments, binds ``--min_radius`` and ``--radius_factor``
    to ``write_segmentation_to_imod_as_points`` and runs the export through
    ``export_helper`` for the given tomogram(s) and segmentation(s).
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # The trailing space matters: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the point-export options so export_helper can call the function uniformly.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command line entry point: export segmented objects as closed contour IMOD models.

    Parses the command line arguments and runs ``export_helper`` with
    ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # The trailing space matters: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def segmentation_cli():
    """Command line entry point: run a segmentation model on tomogram data.

    The model is taken from the model registry (``--model``) unless a custom
    checkpoint is given (``--checkpoint``). If ``--scale`` is not given, the model's
    training resolution is passed to ``inference_helper`` so that the inputs are
    rescaled based on their voxel size. Otherwise, the user-provided factor is applied
    to every spatial axis.

    Raises:
        RuntimeError: If loading the checkpoint yields ``None``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # Advertise the registered model names in the --model help text.
    model_names = list(_get_model_registry().urls.keys())
    model_names = ", ".join(model_names)
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded to inference_helper below,
    # so it currently has no effect. TODO: wire it through or remove it.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        # NOTE(review): weights_only=False unpickles arbitrary Python objects, which can
        # execute code while loading. Only pass checkpoints from trusted sources.
        model = torch.load(args.checkpoint, weights_only=False)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # A "2d" substring in the model name selects YX processing; otherwise ZYX.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )