synapse_net.tools.cli

  1import argparse
  2import os
  3from functools import partial
  4
  5import torch
  6import torch_em
  7from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
  8from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
  9from ..inference.scalable_segmentation import scalable_segmentation
 10from ..inference.util import inference_helper, parse_tiling
 11from .pool_visualization import _visualize_vesicle_pools
 12
 13
 14def imod_point_cli():
 15    parser = argparse.ArgumentParser(
 16        description="Convert a vesicle segmentation to an IMOD point model, "
 17        "corresponding to a sphere for each vesicle in the segmentation."
 18    )
 19    parser.add_argument(
 20        "--input_path", "-i", required=True,
 21        help="The filepath to the mrc file or the directory containing the tomogram data."
 22    )
 23    parser.add_argument(
 24        "--segmentation_path", "-s", required=True,
 25        help="The filepath to the file or the directory containing the segmentations."
 26    )
 27    parser.add_argument(
 28        "--output_path", "-o", required=True,
 29        help="The filepath to directory where the segmentations will be saved."
 30    )
 31    parser.add_argument(
 32        "--segmentation_key", "-k",
 33        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 34        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 35    )
 36    parser.add_argument(
 37        "--min_radius", type=float, default=10.0,
 38        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export."  # noqa
 39    )
 40    parser.add_argument(
 41        "--radius_factor", type=float, default=1.0,
 42        help="A factor for scaling the sphere radius for the export. "
 43        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
 44    )
 45    parser.add_argument(
 46        "--force", action="store_true",
 47        help="Whether to over-write already present export results."
 48    )
 49    args = parser.parse_args()
 50
 51    export_function = partial(
 52        write_segmentation_to_imod_as_points,
 53        min_radius=args.min_radius,
 54        radius_factor=args.radius_factor,
 55    )
 56
 57    export_helper(
 58        input_path=args.input_path,
 59        segmentation_path=args.segmentation_path,
 60        output_root=args.output_path,
 61        export_function=export_function,
 62        force=args.force,
 63        segmentation_key=args.segmentation_key,
 64    )
 65
 66
 67def imod_object_cli():
 68    parser = argparse.ArgumentParser(
 69        description="Convert segmented objects to close contour IMOD models."
 70    )
 71    parser.add_argument(
 72        "--input_path", "-i", required=True,
 73        help="The filepath to the mrc file or the directory containing the tomogram data."
 74    )
 75    parser.add_argument(
 76        "--segmentation_path", "-s", required=True,
 77        help="The filepath to the file or the directory containing the segmentations."
 78    )
 79    parser.add_argument(
 80        "--output_path", "-o", required=True,
 81        help="The filepath to directory where the segmentations will be saved."
 82    )
 83    parser.add_argument(
 84        "--segmentation_key", "-k",
 85        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 86        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 87    )
 88    parser.add_argument(
 89        "--force", action="store_true",
 90        help="Whether to over-write already present export results."
 91    )
 92    args = parser.parse_args()
 93    export_helper(
 94        input_path=args.input_path,
 95        segmentation_path=args.segmentation_path,
 96        output_root=args.output_path,
 97        export_function=write_segmentation_to_imod,
 98        force=args.force,
 99        segmentation_key=args.segmentation_key,
100    )
101
102
def pool_visualization_cli():
    """Command line entry point: visualize a tomogram together with its vesicle pools.

    Arguments are parsed from ``sys.argv`` and forwarded positionally to
    ``_visualize_vesicle_pools``.
    """
    parser = argparse.ArgumentParser(description="Load tomogram data, vesicle pools and additional segmentations for visualization.")  # noqa
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    # segmentation_paths is None when -s is not given.
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )
127
128
# TODO: handle kwargs
def segmentation_cli():
    """Command line entry point: run a segmentation model on tomogram data.

    Arguments are parsed from ``sys.argv``. The model is taken from the model
    registry (``--model``) or loaded from a custom ``--checkpoint``. Tiling and
    rescaling are derived from the arguments, and ``inference_helper`` then runs
    either ``run_segmentation`` or, with ``--scalable``, ``scalable_segmentation``.

    Raises:
        RuntimeError: If a custom checkpoint could not be loaded.
        ValueError: If ``--scalable`` is requested for an unsupported model type.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        # Implicit string concatenation does not add whitespace, hence the trailing space.
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded below, so it currently has no effect.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # Strip a trailing "best.pt" so that the containing torch_em checkpoint directory is loaded.
        if checkpoint_path.endswith("best.pt"):
            checkpoint_path = os.path.split(checkpoint_path)[0]

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # 2D models are recognized by "2d" in the model name.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True
    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )
def imod_point_cli():
    """Command line entry point: export a vesicle segmentation as an IMOD point model.

    Every segmented vesicle is exported as a sphere. Arguments are parsed from
    ``sys.argv``; the export itself is delegated to ``export_helper`` with
    ``write_segmentation_to_imod_as_points`` (parameterized by ``--min_radius``
    and ``--radius_factor``) as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The path to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Implicit string concatenation does not add whitespace, hence the trailing space.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the point-export options so export_helper can call the function uniformly.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def imod_object_cli():
    """Command line entry point: export segmented objects as closed-contour IMOD models.

    Arguments are parsed from ``sys.argv``; the export is delegated to
    ``export_helper`` with ``write_segmentation_to_imod`` as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The path to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Implicit string concatenation does not add whitespace, hence the trailing space.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )
def pool_visualization_cli():
    """Command line entry point: visualize a tomogram together with its vesicle pools.

    Arguments are parsed from ``sys.argv`` and forwarded positionally to
    ``_visualize_vesicle_pools``.
    """
    parser = argparse.ArgumentParser(description="Load tomogram data, vesicle pools and additional segmentations for visualization.")  # noqa
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    # segmentation_paths is None when -s is not given.
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )
def segmentation_cli():
    """Command line entry point: run a segmentation model on tomogram data.

    Arguments are parsed from ``sys.argv``. The model is taken from the model
    registry (``--model``) or loaded from a custom ``--checkpoint``. Tiling and
    rescaling are derived from the arguments, and ``inference_helper`` then runs
    either ``run_segmentation`` or, with ``--scalable``, ``scalable_segmentation``.

    Raises:
        RuntimeError: If a custom checkpoint could not be loaded.
        ValueError: If ``--scalable`` is requested for an unsupported model type.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        # Implicit string concatenation does not add whitespace, hence the trailing space.
        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded below, so it currently has no effect.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        model = get_model(args.model)
    else:
        checkpoint_path = args.checkpoint
        # Strip a trailing "best.pt" so that the containing torch_em checkpoint directory is loaded.
        if checkpoint_path.endswith("best.pt"):
            checkpoint_path = os.path.split(checkpoint_path)[0]

        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
            model = torch_em.util.load_model(checkpoint=checkpoint_path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            model = torch.load(checkpoint_path, weights_only=False)
        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        if model is None:
            raise RuntimeError(f"The model from {args.checkpoint} could not be loaded.")

    # 2D models are recognized by "2d" in the model name.
    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

    # If the scale argument is not passed, then we get the average training resolution for the model.
    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
    if args.scale is None:
        model_resolution = get_model_training_resolution(args.model)
        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        scale = None
    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
    else:
        model_resolution = None
        scale = (2 if is_2d else 3) * (args.scale,)

    if args.scalable:
        if not args.model.startswith(("vesicle", "mito", "active")):
            raise ValueError(
                "The scalable segmentation implementation is currently only supported for "
                f"vesicles, mitochondria, or active zones, not for {args.model}."
            )
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        allocate_output = True
    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output
    )