synapse_net.tools.cli
import argparse
import os
from functools import partial

import torch
import torch_em
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.util import inference_helper, parse_tiling


def imod_point_cli():
    """Command line interface for exporting a vesicle segmentation as IMOD point models.

    Parses the command line arguments and calls ``export_helper`` with
    ``write_segmentation_to_imod_as_points`` as the export function. The
    user-supplied ``min_radius`` and ``radius_factor`` are bound to it via
    ``functools.partial``. Each vesicle in the segmentation is exported as a sphere.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the IMOD exports will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    """Command line interface for exporting segmented objects as closed contour IMOD models.

    Parses the command line arguments and calls ``export_helper`` with
    ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the IMOD exports will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


# TODO: handle kwargs
def segmentation_cli():
    """Command line interface for running a segmentation model on tomogram data.

    The model is fetched by name via ``get_model``. If ``--checkpoint`` is given,
    it is loaded instead from a torch_em checkpoint folder (a path to a ``best.pt``
    file is mapped to its folder) or via ``torch.load`` from a file. Unless
    ``--scale`` is given, the model's training resolution is passed on so that
    the inputs can be rescaled to it. The segmentation is run via ``inference_helper``.

    Raises:
        RuntimeError: If the model could not be loaded from the custom checkpoint.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): args.input_key is parsed but never used below; confirm whether it should be
    # forwarded to inference_helper. Kept so that existing invocations passing -k keep working.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # Map a path to a 'best.pt' file to its folder, which is then loaded as a torch_em checkpoint.
        # Compare the full file name: a suffix check would also match e.g. 'model_best.pt'.
        # A bare 'best.pt' has an empty dirname, so fall back to the current directory.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects, which can execute code.
            # Only load checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )
def
imod_point_cli():
def imod_point_cli():
    """Command line interface for exporting a vesicle segmentation as IMOD point models.

    Parses the command line arguments and calls ``export_helper`` with
    ``write_segmentation_to_imod_as_points`` as the export function. The
    user-supplied ``min_radius`` and ``radius_factor`` are bound to it via
    ``functools.partial``. Each vesicle in the segmentation is exported as a sphere.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the IMOD exports will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def
imod_object_cli():
def imod_object_cli():
    """Command line interface for exporting segmented objects as closed contour IMOD models.

    Parses the command line arguments and calls ``export_helper`` with
    ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the IMOD exports will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def
segmentation_cli():
def segmentation_cli():
    """Command line interface for running a segmentation model on tomogram data.

    The model is fetched by name via ``get_model``. If ``--checkpoint`` is given,
    it is loaded instead from a torch_em checkpoint folder (a path to a ``best.pt``
    file is mapped to its folder) or via ``torch.load`` from a file. Unless
    ``--scale`` is given, the model's training resolution is passed on so that
    the inputs can be rescaled to it. The segmentation is run via ``inference_helper``.

    Raises:
        RuntimeError: If the model could not be loaded from the custom checkpoint.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): args.input_key is parsed but never used below; confirm whether it should be
    # forwarded to inference_helper. Kept so that existing invocations passing -k keep working.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # Map a path to a 'best.pt' file to its folder, which is then loaded as a torch_em checkpoint.
        # Compare the full file name: a suffix check would also match e.g. 'model_best.pt'.
        # A bare 'best.pt' has an empty dirname, so fall back to the current directory.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects, which can execute code.
            # Only load checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )