synapse_net.tools.cli

  1import argparse
  2import os
  3from functools import partial
  4
  5import torch
  6import torch_em
  7from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
  8from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
  9from ..inference.util import inference_helper, parse_tiling
 10
 11
 12def imod_point_cli():
 13    parser = argparse.ArgumentParser(
 14        description="Convert a vesicle segmentation to an IMOD point model, "
 15        "corresponding to a sphere for each vesicle in the segmentation."
 16    )
 17    parser.add_argument(
 18        "--input_path", "-i", required=True,
 19        help="The filepath to the mrc file or the directory containing the tomogram data."
 20    )
 21    parser.add_argument(
 22        "--segmentation_path", "-s", required=True,
 23        help="The filepath to the file or the directory containing the segmentations."
 24    )
 25    parser.add_argument(
 26        "--output_path", "-o", required=True,
 27        help="The filepath to directory where the segmentations will be saved."
 28    )
 29    parser.add_argument(
 30        "--segmentation_key", "-k",
 31        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 32        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 33    )
 34    parser.add_argument(
 35        "--min_radius", type=float, default=10.0,
 36        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export."  # noqa
 37    )
 38    parser.add_argument(
 39        "--radius_factor", type=float, default=1.0,
 40        help="A factor for scaling the sphere radius for the export. "
 41        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
 42    )
 43    parser.add_argument(
 44        "--force", action="store_true",
 45        help="Whether to over-write already present export results."
 46    )
 47    args = parser.parse_args()
 48
 49    export_function = partial(
 50        write_segmentation_to_imod_as_points,
 51        min_radius=args.min_radius,
 52        radius_factor=args.radius_factor,
 53    )
 54
 55    export_helper(
 56        input_path=args.input_path,
 57        segmentation_path=args.segmentation_path,
 58        output_root=args.output_path,
 59        export_function=export_function,
 60        force=args.force,
 61        segmentation_key=args.segmentation_key,
 62    )
 63
 64
 65def imod_object_cli():
 66    parser = argparse.ArgumentParser(
 67        description="Convert segmented objects to close contour IMOD models."
 68    )
 69    parser.add_argument(
 70        "--input_path", "-i", required=True,
 71        help="The filepath to the mrc file or the directory containing the tomogram data."
 72    )
 73    parser.add_argument(
 74        "--segmentation_path", "-s", required=True,
 75        help="The filepath to the file or the directory containing the segmentations."
 76    )
 77    parser.add_argument(
 78        "--output_path", "-o", required=True,
 79        help="The filepath to directory where the segmentations will be saved."
 80    )
 81    parser.add_argument(
 82        "--segmentation_key", "-k",
 83        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 84        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 85    )
 86    parser.add_argument(
 87        "--force", action="store_true",
 88        help="Whether to over-write already present export results."
 89    )
 90    args = parser.parse_args()
 91    export_helper(
 92        input_path=args.input_path,
 93        segmentation_path=args.segmentation_path,
 94        output_root=args.output_path,
 95        export_function=write_segmentation_to_imod,
 96        force=args.force,
 97        segmentation_key=args.segmentation_key,
 98    )
 99
100
# TODO: handle kwargs
def segmentation_cli():
    """Command line interface for running a segmentation model on tomogram data.

    The model is either a registered pretrained model (selected via `--model`)
    or loaded from `--checkpoint`. The checkpoint can be a torch_em checkpoint
    folder, the `best.pt` file inside such a folder, or a file saved with
    `torch.save`. Processing of the input file(s) is delegated to
    `inference_helper`, with `run_segmentation` as the segmentation function.

    Raises:
        RuntimeError: If the model could not be loaded from `--checkpoint`.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    available_models = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {available_models}"
    )
    parser.add_argument(
        # Note the trailing space: the two literals are concatenated implicitly.
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded below, so it
    # currently has no effect. It is kept so existing command lines keep working.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # A path to the 'best.pt' file of a torch_em checkpoint is mapped to its folder.
        # Compare the file name exactly (not a suffix match), so that files such as
        # 'model_best.pt' are still loaded directly. A bare 'best.pt' maps to the
        # current directory instead of the empty string.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # NOTE: weights_only=False unpickles arbitrary objects, so only load
            # checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when running with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )
def imod_point_cli():
    """Command line interface for exporting a vesicle segmentation as IMOD point models.

    Each segmented vesicle is exported as a sphere via
    `write_segmentation_to_imod_as_points`. Objects with a radius below
    `--min_radius` are excluded, and radii are scaled by `--radius_factor`.
    Iterating over the input / segmentation files and writing the results to
    `--output_path` is delegated to `export_helper`.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Note the trailing space: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the export parameters so that export_helper can call the function
    # with only the per-file arguments.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command line interface for exporting segmented objects as closed contour IMOD models.

    Each segmentation is exported with `write_segmentation_to_imod`.
    Iterating over the input / segmentation files and writing the results to
    `--output_path` is delegated to `export_helper`.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Note the trailing space: the two literals are concatenated implicitly.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def segmentation_cli():
    """Command line interface for running a segmentation model on tomogram data.

    The model is either a registered pretrained model (selected via `--model`)
    or loaded from `--checkpoint`. The checkpoint can be a torch_em checkpoint
    folder, the `best.pt` file inside such a folder, or a file saved with
    `torch.save`. Processing of the input file(s) is delegated to
    `inference_helper`, with `run_segmentation` as the segmentation function.

    Raises:
        RuntimeError: If the model could not be loaded from `--checkpoint`.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    available_models = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {available_models}"
    )
    parser.add_argument(
        # Note the trailing space: the two literals are concatenated implicitly.
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded below, so it
    # currently has no effect. It is kept so existing command lines keep working.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # A path to the 'best.pt' file of a torch_em checkpoint is mapped to its folder.
        # Compare the file name exactly (not a suffix match), so that files such as
        # 'model_best.pt' are still loaded directly. A bare 'best.pt' maps to the
        # current directory instead of the empty string.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # NOTE: weights_only=False unpickles arbitrary objects, so only load
            # checkpoints from trusted sources.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when running with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    segmentation_function = partial(
        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
    )
    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )