synapse_net.tools.cli
"""Command-line entry points for IMOD export and for running segmentation."""
import argparse
from functools import partial

import torch
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.util import inference_helper, parse_tiling


def imod_point_cli():
    """Command-line entry point: export a vesicle segmentation as an IMOD point model.

    Parses the command-line arguments, binds ``--min_radius`` and
    ``--radius_factor`` to ``write_segmentation_to_imod_as_points`` and passes
    that export function, together with the input, segmentation and output
    locations, to ``export_helper``.

    Takes no parameters and returns ``None``. ``argparse`` terminates the
    process on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # The trailing space after "tif." keeps the two concatenated literals from running together.
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Pre-bind the point-export options so export_helper receives a ready-made export function.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    """Command-line entry point: export segmented objects as closed-contour IMOD models.

    Parses the command-line arguments and passes them to ``export_helper``
    with ``write_segmentation_to_imod`` as the export function.

    Takes no parameters and returns ``None``. ``argparse`` terminates the
    process on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to close contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # The trailing space after "tif." keeps the two concatenated literals from running together.
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


# TODO: handle kwargs
def segmentation_cli():
    """Command-line entry point: run a segmentation model on tomogram data.

    Parses the command-line arguments, obtains the model (via ``get_model`` for
    ``--model``, or via ``torch.load`` when ``--checkpoint`` is given), derives
    the tiling and the rescaling settings, and passes a pre-configured
    ``run_segmentation`` to ``inference_helper``.

    Note that ``--input_key`` is accepted by the parser but its value is not
    read anywhere in this function.

    Takes no parameters and returns ``None``. ``argparse`` terminates the
    process on ``--help`` or on invalid arguments.

    Raises:
        RuntimeError: If loading ``--checkpoint`` yields ``None``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # List the registered model names in the help text of --model.
    model_names = list(_get_model_registry().urls.keys())
    model_names = ", ".join(model_names)
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    # The trailing space after "segmentation." keeps the two concatenated literals from running together.
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects,
        # so only checkpoints from a trusted source should be passed here.
        model = torch.load(args.checkpoint, weights_only=False)
        # Explicit raise instead of assert so the check survives `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # The model name decides whether 2D or 3D tiling and scaling is used.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        # Repeat the single user-provided factor once per spatial axis.
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )
def imod_point_cli():
    """Command-line entry point: export a vesicle segmentation as an IMOD point model.

    Parses the command-line arguments, binds ``--min_radius`` and
    ``--radius_factor`` to ``write_segmentation_to_imod_as_points`` and passes
    that export function, together with the input, segmentation and output
    locations, to ``export_helper``.

    Takes no parameters and returns ``None``. ``argparse`` terminates the
    process on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # The trailing space after "tif." keeps the two concatenated literals from running together.
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Pre-bind the point-export options so export_helper receives a ready-made export function.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command-line entry point: export segmented objects as closed-contour IMOD models.

    Parses the command-line arguments and passes them to ``export_helper``
    with ``write_segmentation_to_imod`` as the export function.

    Takes no parameters and returns ``None``. ``argparse`` terminates the
    process on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to close contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # The trailing space after "tif." keeps the two concatenated literals from running together.
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def segmentation_cli():
    """Command-line entry point: run a segmentation model on tomogram data.

    Parses the command-line arguments, obtains the model (via ``get_model`` for
    ``--model``, or via ``torch.load`` when ``--checkpoint`` is given), derives
    the tiling and the rescaling settings, and passes a pre-configured
    ``run_segmentation`` to ``inference_helper``.

    Note that ``--input_key`` is accepted by the parser but its value is not
    read anywhere in this function.

    Takes no parameters and returns ``None``. ``argparse`` terminates the
    process on ``--help`` or on invalid arguments.

    Raises:
        RuntimeError: If loading ``--checkpoint`` yields ``None``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # List the registered model names in the help text of --model.
    model_names = list(_get_model_registry().urls.keys())
    model_names = ", ".join(model_names)
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    # The trailing space after "segmentation." keeps the two concatenated literals from running together.
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects,
        # so only checkpoints from a trusted source should be passed here.
        model = torch.load(args.checkpoint, weights_only=False)
        # Explicit raise instead of assert so the check survives `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # The model name decides whether 2D or 3D tiling and scaling is used.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        # Repeat the single user-provided factor once per spatial axis.
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )