# Module: synapse_net.tools.cli

  1import argparse
  2import os
  3from functools import partial
  4
  5import torch
  6import torch_em
  7from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
  8from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
  9from ..inference.scalable_segmentation import scalable_segmentation
 10from ..inference.util import inference_helper, parse_tiling
 11
 12
 13def imod_point_cli():
 14    parser = argparse.ArgumentParser(
 15        description="Convert a vesicle segmentation to an IMOD point model, "
 16        "corresponding to a sphere for each vesicle in the segmentation."
 17    )
 18    parser.add_argument(
 19        "--input_path", "-i", required=True,
 20        help="The filepath to the mrc file or the directory containing the tomogram data."
 21    )
 22    parser.add_argument(
 23        "--segmentation_path", "-s", required=True,
 24        help="The filepath to the file or the directory containing the segmentations."
 25    )
 26    parser.add_argument(
 27        "--output_path", "-o", required=True,
 28        help="The filepath to directory where the segmentations will be saved."
 29    )
 30    parser.add_argument(
 31        "--segmentation_key", "-k",
 32        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 33        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 34    )
 35    parser.add_argument(
 36        "--min_radius", type=float, default=10.0,
 37        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export."  # noqa
 38    )
 39    parser.add_argument(
 40        "--radius_factor", type=float, default=1.0,
 41        help="A factor for scaling the sphere radius for the export. "
 42        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
 43    )
 44    parser.add_argument(
 45        "--force", action="store_true",
 46        help="Whether to over-write already present export results."
 47    )
 48    args = parser.parse_args()
 49
 50    export_function = partial(
 51        write_segmentation_to_imod_as_points,
 52        min_radius=args.min_radius,
 53        radius_factor=args.radius_factor,
 54    )
 55
 56    export_helper(
 57        input_path=args.input_path,
 58        segmentation_path=args.segmentation_path,
 59        output_root=args.output_path,
 60        export_function=export_function,
 61        force=args.force,
 62        segmentation_key=args.segmentation_key,
 63    )
 64
 65
 66def imod_object_cli():
 67    parser = argparse.ArgumentParser(
 68        description="Convert segmented objects to close contour IMOD models."
 69    )
 70    parser.add_argument(
 71        "--input_path", "-i", required=True,
 72        help="The filepath to the mrc file or the directory containing the tomogram data."
 73    )
 74    parser.add_argument(
 75        "--segmentation_path", "-s", required=True,
 76        help="The filepath to the file or the directory containing the segmentations."
 77    )
 78    parser.add_argument(
 79        "--output_path", "-o", required=True,
 80        help="The filepath to directory where the segmentations will be saved."
 81    )
 82    parser.add_argument(
 83        "--segmentation_key", "-k",
 84        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 85        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 86    )
 87    parser.add_argument(
 88        "--force", action="store_true",
 89        help="Whether to over-write already present export results."
 90    )
 91    args = parser.parse_args()
 92    export_helper(
 93        input_path=args.input_path,
 94        segmentation_path=args.segmentation_path,
 95        output_root=args.output_path,
 96        export_function=write_segmentation_to_imod,
 97        force=args.force,
 98        segmentation_key=args.segmentation_key,
 99    )
100
101
# TODO: handle kwargs
def segmentation_cli():
    """Run a segmentation model on tomogram data from the command line.

    Loads either the pretrained model named by ``--model`` (via ``get_model``) or a
    custom model from ``--checkpoint``, derives the tiling and rescaling settings and
    runs the segmentation for all inputs via ``inference_helper``.

    If ``--scale`` is not given, the model's training resolution (from
    ``get_model_training_resolution``) is forwarded to ``inference_helper`` so the
    inputs can be rescaled based on their voxel size; otherwise the user-provided
    factor is applied to every axis.

    With ``--scalable`` the ``scalable_segmentation`` implementation is used instead
    of ``run_segmentation``; this is only allowed for model names starting with
    "vesicle", "mito" or "active".

    Raises:
        ValueError: If ``--scalable`` is requested for an unsupported model.
        RuntimeError: If no model could be loaded from ``--checkpoint``.

    NOTE(review): this function is defined more than once in this module; at import
    time the last definition wins, so the duplicate should be removed.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    # argparse concatenates adjacent string literals, so each sentence needs its trailing space.
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never used below; wire it through or remove it.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    # Validate the argument combination before the potentially expensive model loading.
    if args.scalable and not args.model.startswith(("vesicle", "mito", "active")):
        raise ValueError(
            "The scalable segmentation implementation is currently only supported for "
            f"vesicles, mitochondria, or active zones, not for {args.model}."
        )

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # A path to a file named exactly 'best.pt' is mapped to its directory so that it is
        # loaded as a torch_em checkpoint below. Compare the basename (not a suffix) so that
        # files like 'model_best.pt' are still loaded directly with torch.load.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # 2D models are identified by "2d" in their model name.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    if args.scale is None:
        # Use the average training resolution of the model; the inputs are then rescaled
        # to match it based on the voxel size of the input data.
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    else:
        # Use the user-provided scaling factor for every spatial axis instead.
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True
    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )
def imod_point_cli():
    """Export vesicle segmentations as IMOD point models from the command line.

    Each vesicle in the segmentation is exported as a sphere (a point with a radius)
    by ``write_segmentation_to_imod_as_points``, which is called with the
    ``--min_radius`` and ``--radius_factor`` options bound via ``functools.partial``.
    Iterating over the input / segmentation files and writing to ``--output_path``
    is delegated to ``export_helper``.

    NOTE(review): this function is defined more than once in this module; at import
    time the last definition wins, so the duplicate should be removed.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    # argparse concatenates adjacent string literals, so each sentence needs its trailing space.
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the export options so that export_helper only has to supply the per-file arguments.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Export segmented objects as closed contour IMOD models from the command line.

    Each segmentation is written with ``write_segmentation_to_imod``; iterating over
    the input / segmentation files and writing to ``--output_path`` is delegated to
    ``export_helper``.

    NOTE(review): this function is defined more than once in this module; at import
    time the last definition wins, so the duplicate should be removed.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    # argparse concatenates adjacent string literals, so each sentence needs its trailing space.
    parser.add_argument(
        "--segmentation_key", "-k",
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def segmentation_cli():
    """Run a segmentation model on tomogram data from the command line.

    Loads either the pretrained model named by ``--model`` (via ``get_model``) or a
    custom model from ``--checkpoint``, derives the tiling and rescaling settings and
    runs the segmentation for all inputs via ``inference_helper``.

    If ``--scale`` is not given, the model's training resolution (from
    ``get_model_training_resolution``) is forwarded to ``inference_helper`` so the
    inputs can be rescaled based on their voxel size; otherwise the user-provided
    factor is applied to every axis.

    With ``--scalable`` the ``scalable_segmentation`` implementation is used instead
    of ``run_segmentation``; this is only allowed for model names starting with
    "vesicle", "mito" or "active".

    Raises:
        ValueError: If ``--scalable`` is requested for an unsupported model.
        RuntimeError: If no model could be loaded from ``--checkpoint``.

    NOTE(review): this function is defined more than once in this module; at import
    time the last definition wins, so the duplicate should be removed.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    # argparse concatenates adjacent string literals, so each sentence needs its trailing space.
    parser.add_argument(
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never used below; wire it through or remove it.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    # Validate the argument combination before the potentially expensive model loading.
    if args.scalable and not args.model.startswith(("vesicle", "mito", "active")):
        raise ValueError(
            "The scalable segmentation implementation is currently only supported for "
            f"vesicles, mitochondria, or active zones, not for {args.model}."
        )

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # A path to a file named exactly 'best.pt' is mapped to its directory so that it is
        # loaded as a torch_em checkpoint below. Compare the basename (not a suffix) so that
        # files like 'model_best.pt' are still loaded directly with torch.load.
        if os.path.basename(checkpoint_path) == "best.pt":
            checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # 2D models are identified by "2d" in their model name.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    if args.scale is None:
        # Use the average training resolution of the model; the inputs are then rescaled
        # to match it based on the voxel size of the input data.
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    else:
        # Use the user-provided scaling factor for every spatial axis instead.
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True
    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )