# synapse_net.tools.cli

import argparse
import os
from functools import partial

import torch
import torch_em

from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.scalable_segmentation import scalable_segmentation
from ..inference.util import inference_helper, parse_tiling
from .pool_visualization import _visualize_vesicle_pools


 14def imod_point_cli():
 15    parser = argparse.ArgumentParser(
 16        description="Convert a vesicle segmentation to an IMOD point model, "
 17        "corresponding to a sphere for each vesicle in the segmentation."
 18    )
 19    parser.add_argument(
 20        "--input_path", "-i", required=True,
 21        help="The filepath to the mrc file or the directory containing the tomogram data."
 22    )
 23    parser.add_argument(
 24        "--segmentation_path", "-s", required=True,
 25        help="The filepath to the file or the directory containing the segmentations."
 26    )
 27    parser.add_argument(
 28        "--output_path", "-o", required=True,
 29        help="The filepath to directory where the segmentations will be saved."
 30    )
 31    parser.add_argument(
 32        "--segmentation_key", "-k",
 33        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 34        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 35    )
 36    parser.add_argument(
 37        "--min_radius", type=float, default=10.0,
 38        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export."  # noqa
 39    )
 40    parser.add_argument(
 41        "--radius_factor", type=float, default=1.0,
 42        help="A factor for scaling the sphere radius for the export. "
 43        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
 44    )
 45    parser.add_argument(
 46        "--force", action="store_true",
 47        help="Whether to over-write already present export results."
 48    )
 49    args = parser.parse_args()
 50
 51    export_function = partial(
 52        write_segmentation_to_imod_as_points,
 53        min_radius=args.min_radius,
 54        radius_factor=args.radius_factor,
 55    )
 56
 57    export_helper(
 58        input_path=args.input_path,
 59        segmentation_path=args.segmentation_path,
 60        output_root=args.output_path,
 61        export_function=export_function,
 62        force=args.force,
 63        segmentation_key=args.segmentation_key,
 64    )
 65
 66
 67def imod_object_cli():
 68    parser = argparse.ArgumentParser(
 69        description="Convert segmented objects to close contour IMOD models."
 70    )
 71    parser.add_argument(
 72        "--input_path", "-i", required=True,
 73        help="The filepath to the mrc file or the directory containing the tomogram data."
 74    )
 75    parser.add_argument(
 76        "--segmentation_path", "-s", required=True,
 77        help="The filepath to the file or the directory containing the segmentations."
 78    )
 79    parser.add_argument(
 80        "--output_path", "-o", required=True,
 81        help="The filepath to directory where the segmentations will be saved."
 82    )
 83    parser.add_argument(
 84        "--segmentation_key", "-k",
 85        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
 86        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
 87    )
 88    parser.add_argument(
 89        "--force", action="store_true",
 90        help="Whether to over-write already present export results."
 91    )
 92    args = parser.parse_args()
 93    export_helper(
 94        input_path=args.input_path,
 95        segmentation_path=args.segmentation_path,
 96        output_root=args.output_path,
 97        export_function=write_segmentation_to_imod,
 98        force=args.force,
 99        segmentation_key=args.segmentation_key,
100    )
101
102
def pool_visualization_cli():
    """Command line entry point: visualize vesicle pools together with the tomogram.

    The parsed paths and the ``--split_pools`` flag are passed on to
    ``_visualize_vesicle_pools``, which does the actual loading and display.
    """
    parser = argparse.ArgumentParser(
        description="Load tomogram data, vesicle pools and additional segmentations "
        "for visualization."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    # NOTE(review): vesicle_paths and table_paths are presumably paired one-to-one;
    # their lengths are not validated here - confirm against _visualize_vesicle_pools.
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )


# TODO: handle kwargs
def segmentation_cli():
    """Command line entry point: run a segmentation model on tomogram data.

    Parses the command line, validates the options, loads the model (from the model
    registry via ``--model`` or from ``--checkpoint``), derives the tiling and the input
    rescaling, and runs the segmentation on all inputs via ``inference_helper``.

    Raises:
        ValueError: If ``--scalable`` is requested for a model that does not support it.
        RuntimeError: If a model could not be loaded from ``--checkpoint``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path",
        # Note the trailing space: the two literals are concatenated into one sentence pair.
        help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded to inference_helper; confirm whether it should be.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    parser.add_argument(
        "--extra_input_path", default=None, help="Filepath to extra inputs, needed for cristae segmentation."
    )
    parser.add_argument(
        "--extra_input_ext", default=".tif", help="File extension for the extra inputs, default is tif."
    )
    args = parser.parse_args()

    # Validate option combinations before loading the model, so invalid calls fail fast.
    if args.scalable and not args.model.startswith(("vesicle", "mito", "active")):
        raise ValueError(
            "The scalable segmentation implementation is currently only supported for "
            f"vesicles, mitochondria, or active zones, not for {args.model}."
        )

    model = _load_segmentation_model(args.model, args.checkpoint)

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)
    model_resolution, scale = _get_model_resolution_and_scale(args.model, args.scale, is_2d)

    if args.scalable:
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        # The scalable implementation needs inference_helper to allocate the output for it.
        allocate_output = True
    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output, extra_input_path=args.extra_input_path,
        extra_input_ext=args.extra_input_ext
    )


def _load_segmentation_model(model_type, checkpoint):
    """Load the model used by ``segmentation_cli``.

    Args:
        model_type: The model name given via ``--model``. Used to fetch the pre-trained
            model when no checkpoint is given.
        checkpoint: Optional path given via ``--checkpoint``. Either a torch_em checkpoint
            directory, a file named ``best.pt`` inside such a directory, or a file that
            is loaded with ``torch.load``.

    Returns:
        The loaded model.

    Raises:
        RuntimeError: If loading from the checkpoint returned ``None``.
    """
    if checkpoint is None:
        return get_model(model_type)

    checkpoint_path = checkpoint
    # A path to the "best.pt" file is mapped to its directory, which is then loaded as a torch_em checkpoint.
    # Compare the full basename so that files like "model_best.pt" are not mistaken for it; a bare
    # "best.pt" refers to the current directory.
    if os.path.basename(checkpoint_path) == "best.pt":
        checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

    if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    else:
        # weights_only=False unpickles arbitrary objects: only load checkpoints from trusted sources.
        model = torch.load(checkpoint_path, weights_only=False)

    # Raise explicitly instead of asserting, so the check is not stripped under `python -O`.
    if model is None:
        raise RuntimeError(f"The model from {checkpoint} could not be loaded.")
    return model


def _get_model_resolution_and_scale(model_type, scale_factor, is_2d):
    """Determine how the input data is rescaled before inference.

    Args:
        model_type: The model name given via ``--model``.
        scale_factor: The value of ``--scale`` or ``None``.
        is_2d: Whether the model operates on 2D data.

    Returns:
        A tuple ``(model_resolution, scale)``. Exactly one of the two entries is ``None``.
    """
    # Without an explicit scale factor, use the average training resolution of the model.
    # The inputs are then scaled to match this resolution based on the voxel size from the mrc files.
    if scale_factor is None:
        training_resolution = get_model_training_resolution(model_type)
        model_resolution = tuple(training_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        return model_resolution, None

    # Otherwise, the user-provided factor is applied isotropically along every axis.
    return None, (2 if is_2d else 3) * (scale_factor,)


def imod_point_cli():
    """Command line entry point: export a vesicle segmentation as IMOD point models.

    Each vesicle is exported as a sphere. The export itself is delegated to ``export_helper``
    with ``write_segmentation_to_imod_as_points`` as the export function, configured with the
    ``--min_radius`` and ``--radius_factor`` options.
    """
    parser = argparse.ArgumentParser(
        description="Convert a vesicle segmentation to an IMOD point model, "
        "corresponding to a sphere for each vesicle in the segmentation."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Note the trailing space: the two literals are concatenated into one sentence pair.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--min_radius", type=float, default=10.0,
        help="The minimum vesicle radius in nm. "
        "Objects that are smaller than this radius will be excluded from the export."
    )
    parser.add_argument(
        "--radius_factor", type=float, default=1.0,
        help="A factor for scaling the sphere radius for the export. "
        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()

    # Bind the export options so that export_helper only has to supply the per-file arguments.
    export_function = partial(
        write_segmentation_to_imod_as_points,
        min_radius=args.min_radius,
        radius_factor=args.radius_factor,
    )

    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=export_function,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def imod_object_cli():
    """Command line entry point: export segmented objects as closed contour IMOD models.

    The export is delegated to ``export_helper`` with ``write_segmentation_to_imod``
    as the export function.
    """
    parser = argparse.ArgumentParser(
        description="Convert segmented objects to closed contour IMOD models."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--segmentation_path", "-s", required=True,
        help="The filepath to the file or the directory containing the segmentations."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to the directory where the exported IMOD models will be saved."
    )
    parser.add_argument(
        "--segmentation_key", "-k",
        # Note the trailing space: the two literals are concatenated into one sentence pair.
        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present export results."
    )
    args = parser.parse_args()
    export_helper(
        input_path=args.input_path,
        segmentation_path=args.segmentation_path,
        output_root=args.output_path,
        export_function=write_segmentation_to_imod,
        force=args.force,
        segmentation_key=args.segmentation_key,
    )


def pool_visualization_cli():
    """Command line entry point: visualize vesicle pools together with the tomogram.

    The parsed paths and the ``--split_pools`` flag are passed on to
    ``_visualize_vesicle_pools``, which does the actual loading and display.
    """
    parser = argparse.ArgumentParser(
        description="Load tomogram data, vesicle pools and additional segmentations "
        "for visualization."
    )
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file containing the tomogram data."
    )
    parser.add_argument(
        "--vesicle_paths", "-v", required=True, nargs="+",
        help="The filepath(s) to the tif file(s) containing the vesicle segmentation."
    )
    parser.add_argument(
        "--table_paths", "-t", required=True, nargs="+",
        help="The filepath(s) to the table(s) with the vesicle pool assignments."
    )
    parser.add_argument(
        "-s", "--segmentation_paths", nargs="+", help="Filepaths for additional segmentations."
    )
    parser.add_argument(
        "--split_pools", action="store_true", help="Whether to split the pools into individual layers.",
    )
    args = parser.parse_args()
    # NOTE(review): vesicle_paths and table_paths are presumably paired one-to-one;
    # their lengths are not validated here - confirm against _visualize_vesicle_pools.
    _visualize_vesicle_pools(
        args.input_path, args.vesicle_paths, args.table_paths, args.segmentation_paths, args.split_pools
    )


def segmentation_cli():
    """Command line entry point: run a segmentation model on tomogram data.

    Parses the command line, validates the options, loads the model (from the model
    registry via ``--model`` or from ``--checkpoint``), derives the tiling and the input
    rescaling, and runs the segmentation on all inputs via ``inference_helper``.

    Raises:
        ValueError: If ``--scalable`` is requested for a model that does not support it.
        RuntimeError: If a model could not be loaded from ``--checkpoint``.
    """
    parser = argparse.ArgumentParser(description="Run segmentation.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the mrc file or the directory containing the tomogram data."
    )
    parser.add_argument(
        "--output_path", "-o", required=True,
        help="The filepath to directory where the segmentations will be saved."
    )
    model_names = ", ".join(_get_model_registry().urls.keys())
    parser.add_argument(
        "--model", "-m", required=True,
        help=f"The model type. The following models are currently available: {model_names}"
    )
    parser.add_argument(
        "--mask_path",
        # Note the trailing space: the two literals are concatenated into one sentence pair.
        help="The filepath to a tif file with a mask that will be used to restrict the segmentation. "
        "Can also be a directory with tifs if the filestructure matches input_path."
    )
    # NOTE(review): --input_key is parsed but never forwarded to inference_helper; confirm whether it should be.
    parser.add_argument("--input_key", "-k", required=False)
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
    parser.add_argument(
        "--segmentation_key", "-s",
        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
    )
    parser.add_argument(
        "--scale", type=float,
        help="The factor for rescaling the data before inference. "
        "By default, the scaling factor will be derived from the voxel size of the input data. "
        "If this parameter is given it will over-ride the default behavior. "
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Whether to print verbose information about the segmentation progress."
    )
    parser.add_argument(
        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
        "Currently this only works for vesicles, mitochondria, or active zones."
    )
    parser.add_argument(
        "--extra_input_path", default=None, help="Filepath to extra inputs, needed for cristae segmentation."
    )
    parser.add_argument(
        "--extra_input_ext", default=".tif", help="File extension for the extra inputs, default is tif."
    )
    args = parser.parse_args()

    # Validate option combinations before loading the model, so invalid calls fail fast.
    if args.scalable and not args.model.startswith(("vesicle", "mito", "active")):
        raise ValueError(
            "The scalable segmentation implementation is currently only supported for "
            f"vesicles, mitochondria, or active zones, not for {args.model}."
        )

    model = _load_segmentation_model(args.model, args.checkpoint)

    is_2d = "2d" in args.model
    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)
    model_resolution, scale = _get_model_resolution_and_scale(args.model, args.scale, is_2d)

    if args.scalable:
        segmentation_function = partial(
            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
        )
        # The scalable implementation needs inference_helper to allocate the output for it.
        allocate_output = True
    else:
        segmentation_function = partial(
            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
        )
        allocate_output = False

    inference_helper(
        args.input_path, args.output_path, segmentation_function,
        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
        allocate_output=allocate_output, extra_input_path=args.extra_input_path,
        extra_input_ext=args.extra_input_ext
    )


def _load_segmentation_model(model_type, checkpoint):
    """Load the model used by ``segmentation_cli``.

    Args:
        model_type: The model name given via ``--model``. Used to fetch the pre-trained
            model when no checkpoint is given.
        checkpoint: Optional path given via ``--checkpoint``. Either a torch_em checkpoint
            directory, a file named ``best.pt`` inside such a directory, or a file that
            is loaded with ``torch.load``.

    Returns:
        The loaded model.

    Raises:
        RuntimeError: If loading from the checkpoint returned ``None``.
    """
    if checkpoint is None:
        return get_model(model_type)

    checkpoint_path = checkpoint
    # A path to the "best.pt" file is mapped to its directory, which is then loaded as a torch_em checkpoint.
    # Compare the full basename so that files like "model_best.pt" are not mistaken for it; a bare
    # "best.pt" refers to the current directory.
    if os.path.basename(checkpoint_path) == "best.pt":
        checkpoint_path = os.path.dirname(checkpoint_path) or os.curdir

    if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    else:
        # weights_only=False unpickles arbitrary objects: only load checkpoints from trusted sources.
        model = torch.load(checkpoint_path, weights_only=False)

    # Raise explicitly instead of asserting, so the check is not stripped under `python -O`.
    if model is None:
        raise RuntimeError(f"The model from {checkpoint} could not be loaded.")
    return model


def _get_model_resolution_and_scale(model_type, scale_factor, is_2d):
    """Determine how the input data is rescaled before inference.

    Args:
        model_type: The model name given via ``--model``.
        scale_factor: The value of ``--scale`` or ``None``.
        is_2d: Whether the model operates on 2D data.

    Returns:
        A tuple ``(model_resolution, scale)``. Exactly one of the two entries is ``None``.
    """
    # Without an explicit scale factor, use the average training resolution of the model.
    # The inputs are then scaled to match this resolution based on the voxel size from the mrc files.
    if scale_factor is None:
        training_resolution = get_model_training_resolution(model_type)
        model_resolution = tuple(training_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
        return model_resolution, None

    # Otherwise, the user-provided factor is applied isotropically along every axis.
    return None, (2 if is_2d else 3) * (scale_factor,)