synapse_net.tools.cli

  1import argparse
  2from functools import partial
  3
  4import torch
  5from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
  6from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
  7from ..inference.util import inference_helper, parse_tiling
  8
  9
 10def imod_point_cli():
 11    parser = argparse.ArgumentParser(
 12        description="Convert a vesicle segmentation to an IMOD point model, "
 13        "corresponding to a sphere for each vesicle in the segmentation."
 14    )
 15    parser.add_argument(
 16        "--input_path", "-i", required=True,
 17        help="The filepath to the mrc file or the directory containing the tomogram data."
 18    )
 19    parser.add_argument(
 20        "--segmentation_path", "-s", required=True,
 21        help="The filepath to the file or the directory containing the segmentations."
 22    )
 23    parser.add_argument(
 24        "--output_path", "-o", required=True,
 25        help="The filepath to directory where the segmentations will be saved."
 26    )
 27    parser.add_argument(
 28        "--segmentation_key", "-k",
 29        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 30        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 31    )
 32    parser.add_argument(
 33        "--min_radius", type=float, default=10.0,
 34        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export."  # noqa
 35    )
 36    parser.add_argument(
 37        "--radius_factor", type=float, default=1.0,
 38        help="A factor for scaling the sphere radius for the export. "
 39        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
 40    )
 41    parser.add_argument(
 42        "--force", action="store_true",
 43        help="Whether to over-write already present export results."
 44    )
 45    args = parser.parse_args()
 46
 47    export_function = partial(
 48        write_segmentation_to_imod_as_points,
 49        min_radius=args.min_radius,
 50        radius_factor=args.radius_factor,
 51    )
 52
 53    export_helper(
 54        input_path=args.input_path,
 55        segmentation_path=args.segmentation_path,
 56        output_root=args.output_path,
 57        export_function=export_function,
 58        force=args.force,
 59        segmentation_key=args.segmentation_key,
 60    )
 61
 62
 63def imod_object_cli():
 64    parser = argparse.ArgumentParser(
 65        description="Convert segmented objects to close contour IMOD models."
 66    )
 67    parser.add_argument(
 68        "--input_path", "-i", required=True,
 69        help="The filepath to the mrc file or the directory containing the tomogram data."
 70    )
 71    parser.add_argument(
 72        "--segmentation_path", "-s", required=True,
 73        help="The filepath to the file or the directory containing the segmentations."
 74    )
 75    parser.add_argument(
 76        "--output_path", "-o", required=True,
 77        help="The filepath to directory where the segmentations will be saved."
 78    )
 79    parser.add_argument(
 80        "--segmentation_key", "-k",
 81        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 82        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 83    )
 84    parser.add_argument(
 85        "--force", action="store_true",
 86        help="Whether to over-write already present export results."
 87    )
 88    args = parser.parse_args()
 89    export_helper(
 90        input_path=args.input_path,
 91        segmentation_path=args.segmentation_path,
 92        output_root=args.output_path,
 93        export_function=write_segmentation_to_imod,
 94        force=args.force,
 95        segmentation_key=args.segmentation_key,
 96    )
 97
 98
 99# TODO: handle kwargs
100def segmentation_cli():
101    parser = argparse.ArgumentParser(description="Run segmentation.")
102    parser.add_argument(
103        "--input_path", "-i", required=True,
104        help="The filepath to the mrc file or the directory containing the tomogram data."
105    )
106    parser.add_argument(
107        "--output_path", "-o", required=True,
108        help="The filepath to directory where the segmentations will be saved."
109    )
110    model_names = list(_get_model_registry().urls.keys())
111    model_names = ", ".join(model_names)
112    parser.add_argument(
113        "--model", "-m", required=True,
114        help=f"The model type. The following models are currently available: {model_names}"
115    )
116    parser.add_argument(
117        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation."
118        "Can also be a directory with tifs if the filestructure matches input_path."
119    )
120    parser.add_argument("--input_key", "-k", required=False)
121    parser.add_argument(
122        "--force", action="store_true",
123        help="Whether to over-write already present segmentation results."
124    )
125    parser.add_argument(
126        "--tile_shape", type=int, nargs=3,
127        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
128    )
129    parser.add_argument(
130        "--halo", type=int, nargs=3,
131        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
132    )
133    parser.add_argument(
134        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
135    )
136    parser.add_argument(
137        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
138    )
139    parser.add_argument(
140        "--segmentation_key", "-s",
141        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
142    )
143    parser.add_argument(
144        "--scale", type=float,
145        help="The factor for rescaling the data before inference. "
146        "By default, the scaling factor will be derived from the voxel size of the input data. "
147        "If this parameter is given it will over-ride the default behavior. "
148    )
149    args = parser.parse_args()
150
151    if args.checkpoint is None:
152        model = get_model(args.model)
153    else:
154        model = torch.load(args.checkpoint, weights_only=False)
155        assert model is not None, f"The model from {args.checkpoint} could not be loaded."
156
157    is_2d = "2d" in args.model
158    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)
159
160    # If the scale argument is not passed, then we get the average training resolution for the model.
161    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
162    if args.scale is None:
163        model_resolution = get_model_training_resolution(args.model)
164        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
165        scale = None
166    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
167    else:
168        model_resolution = None
169        scale = (2 if is_2d else 3) * (args.scale,)
170
171    segmentation_function = partial(
172        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
173    )
174    inference_helper(
175        args.input_path, args.output_path, segmentation_function,
176        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
177        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
178    )
def imod_point_cli():
    """Command line entry point for exporting vesicle segmentations as IMOD point models.

    Parses the command line arguments (argparse reads them from ``sys.argv``),
    binds ``min_radius`` and ``radius_factor`` to
    ``write_segmentation_to_imod_as_points`` and passes the resulting export
    function to ``export_helper`` together with the input, segmentation and
    output locations.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the export results will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # The trailing space keeps the two sentences separated after literal concatenation.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. "
        "Objects that are smaller than this radius will be excluded from the export."
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Fix the radius options up front so that export_helper receives a ready-to-call export function.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command line entry point for exporting segmented objects as closed contour IMOD models.

    Parses the command line arguments (argparse reads them from ``sys.argv``)
    and calls ``export_helper`` with ``write_segmentation_to_imod`` as the
    export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to close contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the export results will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # The trailing space keeps the two sentences separated after literal concatenation.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def segmentation_cli():
    """Command line entry point for running segmentation on tomogram data.

    Parses the command line arguments (argparse reads them from ``sys.argv``),
    obtains the model (via ``get_model`` for the type given by ``--model``, or
    via ``torch.load`` when ``--checkpoint`` is given), derives the tiling and
    the rescaling settings and passes everything to ``inference_helper``.

    Raises:
        RuntimeError: If loading the custom checkpoint yields ``None``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    # List the registered model names in the help text of --model.
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        # The trailing space keeps the two sentences separated after literal concatenation.
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never read below - confirm whether it
    # should be forwarded to the inference code or removed.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so this can
        # execute code embedded in the file. Only load checkpoints from trusted sources.
        model = torch.load(args.checkpoint, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # The model type name decides between 2d and 3d handling for tiling and scaling.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        # Repeat the user-provided factor once per spatial axis (2 for 2d models, 3 otherwise).
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )