synapse_net.tools.cli

  1import argparse
  2from functools import partial
  3
  4import torch
  5from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
  6from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
  7from ..inference.util import inference_helper, parse_tiling
  8
  9
 10def imod_point_cli():
 11    parser = argparse.ArgumentParser(
 12        description="Convert a vesicle segmentation to an IMOD point model, "
 13        "corresponding to a sphere for each vesicle in the segmentation."
 14    )
 15    parser.add_argument(
 16        "--input_path", "-i", required=True,
 17        help="The filepath to the mrc file or the directory containing the tomogram data."
 18    )
 19    parser.add_argument(
 20        "--segmentation_path", "-s", required=True,
 21        help="The filepath to the file or the directory containing the segmentations."
 22    )
 23    parser.add_argument(
 24        "--output_path", "-o", required=True,
 25        help="The filepath to directory where the segmentations will be saved."
 26    )
 27    parser.add_argument(
 28        "--segmentation_key", "-k",
 29        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 30        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 31    )
 32    parser.add_argument(
 33        "--min_radius", type=float, default=10.0,
 34        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export."  # noqa
 35    )
 36    parser.add_argument(
 37        "--radius_factor", type=float, default=1.0,
 38        help="A factor for scaling the sphere radius for the export. "
 39        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
 40    )
 41    parser.add_argument(
 42        "--force", action="store_true",
 43        help="Whether to over-write already present export results."
 44    )
 45    args = parser.parse_args()
 46
 47    export_function = partial(
 48        write_segmentation_to_imod_as_points,
 49        min_radius=args.min_radius,
 50        radius_factor=args.radius_factor,
 51    )
 52
 53    export_helper(
 54        input_path=args.input_path,
 55        segmentation_path=args.segmentation_path,
 56        output_root=args.output_path,
 57        export_function=export_function,
 58        force=args.force,
 59        segmentation_key=args.segmentation_key,
 60    )
 61
 62
 63def imod_object_cli():
 64    parser = argparse.ArgumentParser(
 65        description="Convert segmented objects to close contour IMOD models."
 66    )
 67    parser.add_argument(
 68        "--input_path", "-i", required=True,
 69        help="The filepath to the mrc file or the directory containing the tomogram data."
 70    )
 71    parser.add_argument(
 72        "--segmentation_path", "-s", required=True,
 73        help="The filepath to the file or the directory containing the segmentations."
 74    )
 75    parser.add_argument(
 76        "--output_path", "-o", required=True,
 77        help="The filepath to directory where the segmentations will be saved."
 78    )
 79    parser.add_argument(
 80        "--segmentation_key", "-k",
 81        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 82        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 83    )
 84    parser.add_argument(
 85        "--force", action="store_true",
 86        help="Whether to over-write already present export results."
 87    )
 88    args = parser.parse_args()
 89    export_helper(
 90        input_path=args.input_path,
 91        segmentation_path=args.segmentation_path,
 92        output_root=args.output_path,
 93        export_function=write_segmentation_to_imod,
 94        force=args.force,
 95        segmentation_key=args.segmentation_key,
 96    )
 97
 98
 99# TODO: handle kwargs
100def segmentation_cli():
101    parser = argparse.ArgumentParser(description="Run segmentation.")
102    parser.add_argument(
103        "--input_path", "-i", required=True,
104        help="The filepath to the mrc file or the directory containing the tomogram data."
105    )
106    parser.add_argument(
107        "--output_path", "-o", required=True,
108        help="The filepath to directory where the segmentations will be saved."
109    )
110    model_names = list(_get_model_registry().urls.keys())
111    model_names = ", ".join(model_names)
112    parser.add_argument(
113        "--model", "-m", required=True,
114        help=f"The model type. The following models are currently available: {model_names}"
115    )
116    parser.add_argument(
117        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation."
118        "Can also be a directory with tifs if the filestructure matches input_path."
119    )
120    parser.add_argument("--input_key", "-k", required=False)
121    parser.add_argument(
122        "--force", action="store_true",
123        help="Whether to over-write already present segmentation results."
124    )
125    parser.add_argument(
126        "--tile_shape", type=int, nargs=3,
127        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
128    )
129    parser.add_argument(
130        "--halo", type=int, nargs=3,
131        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
132    )
133    parser.add_argument(
134        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
135    )
136    parser.add_argument(
137        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
138    )
139    parser.add_argument(
140        "--segmentation_key", "-s",
141        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
142    )
143    parser.add_argument(
144        "--scale", type=float,
145        help="The factor for rescaling the data before inference. "
146        "By default, the scaling factor will be derived from the voxel size of the input data. "
147        "If this parameter is given it will over-ride the default behavior. "
148    )
149    parser.add_argument(
150        "--verbose", "-v", action="store_true",
151        help="Whether to print verbose information about the segmentation progress."
152    )
153    args = parser.parse_args()
154
155    if args.checkpoint is None:
156        model = get_model(args.model)
157    else:
158        model = torch.load(args.checkpoint, weights_only=False)
159        assert model is not None, f"The model from {args.checkpoint} could not be loaded."
160
161    is_2d = "2d" in args.model
162    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)
163
164    # If the scale argument is not passed, then we get the average training resolution for the model.
165    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
166    if args.scale is None:
167        model_resolution = get_model_training_resolution(args.model)
168        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
169        scale = None
170    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
171    else:
172        model_resolution = None
173        scale = (2 if is_2d else 3) * (args.scale,)
174
175    segmentation_function = partial(
176        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
177    )
178    inference_helper(
179        args.input_path, args.output_path, segmentation_function,
180        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
181        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
182    )
def imod_point_cli():
    """Command line entry point for exporting vesicle segmentations as IMOD point models.

    Parses the command line arguments and exports each segmented vesicle as a sphere by
    passing ``write_segmentation_to_imod_as_points`` (configured with ``--min_radius`` and
    ``--radius_factor``) to ``export_helper``, which handles the input / output paths.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Note the trailing space: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the export options so that export_helper can call the function with the data only.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command line entry point for exporting segmented objects as closed contour IMOD models.

    Parses the command line arguments and passes ``write_segmentation_to_imod`` to
    ``export_helper``, which handles the input / output paths.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Note the trailing space: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def segmentation_cli():
    """Command line entry point for running segmentation with a SynapseNet model.

    The model is either fetched from the model registry (``--model``) or loaded from a
    custom checkpoint (``--checkpoint``). Unless ``--scale`` is given, the data is rescaled
    to match the model's training resolution. The actual processing of the input file(s)
    is delegated to ``inference_helper``.

    Raises:
        RuntimeError: If loading the custom checkpoint yields ``None``.
    """
    import warnings

    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        # Note the trailing space: the two literals are concatenated implicitly.
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    # --input_key is parsed but not forwarded to inference_helper; warn rather than ignore it silently.
    if args.input_key is not None:
        warnings.warn(f"The argument --input_key={args.input_key} is currently not supported and will be ignored.")

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        # weights_only=False unpickles the full model object stored in the checkpoint.
        # This can execute arbitrary code, so only load checkpoints from trusted sources.
        model = torch.load(args.checkpoint, weights_only=False)
        # Explicit check instead of assert, which would be stripped when running with `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # Model names containing "2d" are treated as 2D models.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user,
    # applied equally along every axis.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )