micro_sam.sam_annotator.training_ui

  1import os
  2import warnings
  3
  4from qtpy import QtWidgets
  5# from napari.qt.threading import thread_worker
  6
  7import torch_em
  8from torch.utils.data import random_split
  9
 10import micro_sam.util as util
 11import micro_sam.sam_annotator._widgets as widgets
 12from micro_sam.training.training import _find_best_configuration, _export_helper
 13from micro_sam.training import default_sam_dataset, train_sam_for_configuration, CONFIGURATIONS
 14
 15from ._tooltips import get_tooltip
 16
 17
 18class TrainingWidget(widgets._WidgetBase):
 19    def __init__(self, parent=None):
 20        super().__init__(parent=parent)
 21
 22        # Create the UI: the general options.
 23        self._create_options()
 24
 25        # Add the settings (collapsible).
 26        self.layout().addWidget(self._create_settings())
 27
 28        # Add the run button to trigger the embedding computation.
 29        self.run_button = QtWidgets.QPushButton("Start Training")
 30        self.run_button.clicked.connect(self.__call__)
 31        self.layout().addWidget(self.run_button)
 32
 33    def _create_options(self):
 34        self.raw_path = None
 35        _, layout = self._add_path_param(
 36            "raw_path", self.raw_path, "both", placeholder="/path/to/images", title="Path to images",
 37            tooltip=get_tooltip("training", "raw_path")
 38        )
 39        self.layout().addLayout(layout)
 40
 41        self.raw_key = None
 42        _, layout = self._add_string_param(
 43            "raw_key", self.raw_key, placeholder="e.g. \"*.tif\"", title="Image data key",
 44            tooltip=get_tooltip("training", "raw_key")
 45        )
 46        self.layout().addLayout(layout)
 47
 48        self.label_path = None
 49        _, layout = self._add_path_param(
 50            "label_path", self.label_path, "both", placeholder="/path/to/labels", title="Path to labels",
 51            tooltip=get_tooltip("training", "label_path")
 52        )
 53        self.layout().addLayout(layout)
 54
 55        self.label_key = None
 56        _, layout = self._add_string_param(
 57            "label_key", self.label_key, placeholder="e.g. \"*.tif\"", title="Label data key",
 58            tooltip=get_tooltip("training", "label_key")
 59        )
 60        self.layout().addLayout(layout)
 61
 62        self.configuration = _find_best_configuration()
 63        self.setting_dropdown, layout = self._add_choice_param(
 64            "configuration", self.configuration, list(CONFIGURATIONS.keys()), title="Configuration",
 65            tooltip=get_tooltip("training", "configuration")
 66        )
 67        self.layout().addLayout(layout)
 68
 69        self.with_segmentation_decoder = True
 70        self.layout().addWidget(self._add_boolean_param(
 71            "with_segmentation_decoder", self.with_segmentation_decoder, title="With segmentation decoder",
 72            tooltip=get_tooltip("training", "segmentation_decoder")
 73        ))
 74
 75    def _create_settings(self):
 76        setting_values = QtWidgets.QWidget()
 77        setting_values.setLayout(QtWidgets.QVBoxLayout())
 78
 79        # TODO use CPU instead of MPS on MAC because training with MPS is slower!
 80        # Device and patch shape settings.
 81        self.device = "auto"
 82        device_options = ["auto"] + util._available_devices()
 83        self.device_dropdown, layout = self._add_choice_param(
 84            "device", self.device, device_options, title="Device", tooltip=get_tooltip("training", "device")
 85        )
 86        setting_values.layout().addLayout(layout)
 87
 88        self.patch_x, self.patch_y = 512, 512
 89        self.patch_x_param, self.patch_y_param, layout = self._add_shape_param(
 90            ("patch_x", "patch_y"), (self.patch_x, self.patch_y), min_val=0, max_val=2048,
 91            tooltip=get_tooltip("training", "patch"), title=("Patch size x", "Patch size y")
 92        )
 93        setting_values.layout().addLayout(layout)
 94
 95        # Paths for validation data.
 96        self.raw_path_val = None
 97        _, layout = self._add_path_param(
 98            "raw_path_val", self.raw_path_val, "both", placeholder="/path/to/images",
 99            title="Path to validation images", tooltip=get_tooltip("training", "raw_path_val")
100        )
101        setting_values.layout().addLayout(layout)
102
103        self.label_path_val = None
104        _, layout = self._add_path_param(
105            "label_path_val", self.label_path_val, "both", placeholder="/path/to/images",
106            title="Path to validation labels", tooltip=get_tooltip("training", "label_path_val")
107        )
108        setting_values.layout().addLayout(layout)
109
110        # Name of the model to be trained and options to over-ride the initial model
111        # on top of which the finetuning is run.
112        self.name = "sam_model"
113        self.name_param, layout = self._add_string_param(
114            "name", self.name, title="Name of Trained Model", tooltip=get_tooltip("training", "name")
115        )
116        setting_values.layout().addLayout(layout)
117
118        # Add the model family widget section.
119        layout = self._create_model_section(default_model="vit_b", create_layout=False)
120        setting_values.layout().addLayout(layout)
121
122        # Add the model size widget section.
123        layout = self._create_model_size_section()
124        setting_values.layout().addLayout(layout)
125
126        self.custom_weights = None
127        self.custom_weights_param, layout = self._add_string_param(
128            "custom_weights", self.custom_weights, title="Custom Weights", tooltip=get_tooltip("training", "checkpoint")
129        )
130        setting_values.layout().addLayout(layout)
131
132        self.output_path = None
133        self.output_path_param, layout = self._add_string_param(
134            "output_path", self.output_path, title="Output Path", tooltip=get_tooltip("training", "output_path")
135        )
136        setting_values.layout().addLayout(layout)
137
138        self.n_epochs = 100
139        self.n_epochs_param, layout = self._add_int_param(
140            "n_epochs", self.n_epochs, title="Number of epochs", min_val=1, max_val=1000,
141            tooltip=get_tooltip("training", "n_epochs"),
142        )
143        setting_values.layout().addLayout(layout)
144
145        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
146        return settings
147
148    def _get_loaders(self):
149        batch_size = 1
150        num_workers = 1 if str(self.device) == "cpu" else 4
151
152        patch_shape = (self.patch_x, self.patch_y)
153        dataset = default_sam_dataset(
154            raw_paths=str(self.raw_path),
155            raw_key=self.raw_key,
156            label_paths=str(self.label_path),
157            label_key=self.label_key,
158            patch_shape=patch_shape,
159            with_segmentation_decoder=self.with_segmentation_decoder,
160        )
161
162        raw_path_val, label_path_val = self.raw_path_val, self.label_path_val
163        if raw_path_val is None:
164            # Use 10% of the dataset - at least one image - for validation.
165            n_val = min(1, int(0.1 * len(dataset)))
166            train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
167        else:
168            train_dataset = dataset
169            val_dataset = default_sam_dataset(
170                raw_paths=str(raw_path_val),
171                raw_key=self.raw_key,
172                label_paths=str(label_path_val),
173                label_key=self.label_key,
174                patch_shape=patch_shape,
175                with_segmentation_decoder=self.with_segmentation_decoder,
176            )
177
178        train_loader = torch_em.segmentation.get_data_loader(
179            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
180        )
181        val_loader = torch_em.segmentation.get_data_loader(
182            val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
183        )
184        return train_loader, val_loader
185
186    def _get_model_type(self):
187        # Consolidate initial model name, the checkpoint path and the model type according to the configuration.
188        suitable_model_type = CONFIGURATIONS[self.configuration]["model_type"]
189        if self.model_type[:5] == suitable_model_type:
190            self.model_type = suitable_model_type
191        else:
192            warnings.warn(
193                f"You have changed the model type for your chosen configuration '{self.configuration}' "
194                f"from '{suitable_model_type}' to '{self.model_type}'. "
195                "The training may be extremely slow. Please be aware of your custom model choice."
196            )
197
198        assert self.model_type is not None
199
200    # Make sure that raw and label path have been passed.
 201    # If they have not, an error message is generated and returned.
202    # (We could do a more extensive validation here, but for now keep it minimal.)
203    def _validate_inputs(self):
204        missing_raw = self.raw_path is None or not os.path.exists(self.raw_path)
205        missing_label = self.label_path is None or not os.path.exists(self.label_path)
206        if missing_raw or missing_label:
207            msg = ""
208            if missing_raw:
209                msg += "The path to raw data is missing or does not exist. "
210            if missing_label:
211                msg += "The path to label data is missing or does not exist."
212            return widgets._generate_message("error", msg)
213        return False
214
215    def __call__(self, skip_validate=False):
216        self._validate_model_type_and_custom_weights()
217
218        if not skip_validate and self._validate_inputs():
219            return
220
221        # Set up progress bar and signals for using it within a threadworker.
222        pbar, pbar_signals = widgets._create_pbar_for_threadworker()
223
224        self._get_model_type()
225        if self.custom_weights is None:
226            model_registry = util.models()
227            checkpoint_path = model_registry.fetch(self.model_type)
228        else:
229            checkpoint_path = self.custom_weights
230
231        # @thread_worker()
232        def run_training():
233            train_loader, val_loader = self._get_loaders()
234            train_sam_for_configuration(
235                name=self.name,
236                configuration=self.configuration,
237                train_loader=train_loader,
238                val_loader=val_loader,
239                checkpoint_path=checkpoint_path,
240                with_segmentation_decoder=self.with_segmentation_decoder,
241                model_type=self.model_type,
242                device=self.device,
243                n_epochs=self.n_epochs,
244                pbar_signals=pbar_signals,
245            )
246
247            # The best checkpoint after training.
248            export_checkpoint = os.path.join("checkpoints", self.name, "best.pt")
249            assert os.path.exists(export_checkpoint), export_checkpoint
250
251            output_path = _export_helper(
252                "", self.name, self.output_path, self.model_type, self.with_segmentation_decoder, val_loader
253            )
254            pbar_signals.pbar_stop.emit()
255            return output_path
256
257        path = run_training()
258        print(f"Training has finished. The trained model is saved at {path}.")
259        # worker = run_training()
260        # worker.returned.connect(lambda path: print(f"Training has finished. The trained model is saved at {path}."))
261        # worker.start()
262        # return worker
class TrainingWidget(micro_sam.sam_annotator._widgets._WidgetBase):
 19class TrainingWidget(widgets._WidgetBase):
 20    def __init__(self, parent=None):
 21        super().__init__(parent=parent)
 22
 23        # Create the UI: the general options.
 24        self._create_options()
 25
 26        # Add the settings (collapsible).
 27        self.layout().addWidget(self._create_settings())
 28
 29        # Add the run button to trigger the embedding computation.
 30        self.run_button = QtWidgets.QPushButton("Start Training")
 31        self.run_button.clicked.connect(self.__call__)
 32        self.layout().addWidget(self.run_button)
 33
 34    def _create_options(self):
 35        self.raw_path = None
 36        _, layout = self._add_path_param(
 37            "raw_path", self.raw_path, "both", placeholder="/path/to/images", title="Path to images",
 38            tooltip=get_tooltip("training", "raw_path")
 39        )
 40        self.layout().addLayout(layout)
 41
 42        self.raw_key = None
 43        _, layout = self._add_string_param(
 44            "raw_key", self.raw_key, placeholder="e.g. \"*.tif\"", title="Image data key",
 45            tooltip=get_tooltip("training", "raw_key")
 46        )
 47        self.layout().addLayout(layout)
 48
 49        self.label_path = None
 50        _, layout = self._add_path_param(
 51            "label_path", self.label_path, "both", placeholder="/path/to/labels", title="Path to labels",
 52            tooltip=get_tooltip("training", "label_path")
 53        )
 54        self.layout().addLayout(layout)
 55
 56        self.label_key = None
 57        _, layout = self._add_string_param(
 58            "label_key", self.label_key, placeholder="e.g. \"*.tif\"", title="Label data key",
 59            tooltip=get_tooltip("training", "label_key")
 60        )
 61        self.layout().addLayout(layout)
 62
 63        self.configuration = _find_best_configuration()
 64        self.setting_dropdown, layout = self._add_choice_param(
 65            "configuration", self.configuration, list(CONFIGURATIONS.keys()), title="Configuration",
 66            tooltip=get_tooltip("training", "configuration")
 67        )
 68        self.layout().addLayout(layout)
 69
 70        self.with_segmentation_decoder = True
 71        self.layout().addWidget(self._add_boolean_param(
 72            "with_segmentation_decoder", self.with_segmentation_decoder, title="With segmentation decoder",
 73            tooltip=get_tooltip("training", "segmentation_decoder")
 74        ))
 75
 76    def _create_settings(self):
 77        setting_values = QtWidgets.QWidget()
 78        setting_values.setLayout(QtWidgets.QVBoxLayout())
 79
 80        # TODO use CPU instead of MPS on MAC because training with MPS is slower!
 81        # Device and patch shape settings.
 82        self.device = "auto"
 83        device_options = ["auto"] + util._available_devices()
 84        self.device_dropdown, layout = self._add_choice_param(
 85            "device", self.device, device_options, title="Device", tooltip=get_tooltip("training", "device")
 86        )
 87        setting_values.layout().addLayout(layout)
 88
 89        self.patch_x, self.patch_y = 512, 512
 90        self.patch_x_param, self.patch_y_param, layout = self._add_shape_param(
 91            ("patch_x", "patch_y"), (self.patch_x, self.patch_y), min_val=0, max_val=2048,
 92            tooltip=get_tooltip("training", "patch"), title=("Patch size x", "Patch size y")
 93        )
 94        setting_values.layout().addLayout(layout)
 95
 96        # Paths for validation data.
 97        self.raw_path_val = None
 98        _, layout = self._add_path_param(
 99            "raw_path_val", self.raw_path_val, "both", placeholder="/path/to/images",
100            title="Path to validation images", tooltip=get_tooltip("training", "raw_path_val")
101        )
102        setting_values.layout().addLayout(layout)
103
104        self.label_path_val = None
105        _, layout = self._add_path_param(
106            "label_path_val", self.label_path_val, "both", placeholder="/path/to/images",
107            title="Path to validation labels", tooltip=get_tooltip("training", "label_path_val")
108        )
109        setting_values.layout().addLayout(layout)
110
111        # Name of the model to be trained and options to over-ride the initial model
112        # on top of which the finetuning is run.
113        self.name = "sam_model"
114        self.name_param, layout = self._add_string_param(
115            "name", self.name, title="Name of Trained Model", tooltip=get_tooltip("training", "name")
116        )
117        setting_values.layout().addLayout(layout)
118
119        # Add the model family widget section.
120        layout = self._create_model_section(default_model="vit_b", create_layout=False)
121        setting_values.layout().addLayout(layout)
122
123        # Add the model size widget section.
124        layout = self._create_model_size_section()
125        setting_values.layout().addLayout(layout)
126
127        self.custom_weights = None
128        self.custom_weights_param, layout = self._add_string_param(
129            "custom_weights", self.custom_weights, title="Custom Weights", tooltip=get_tooltip("training", "checkpoint")
130        )
131        setting_values.layout().addLayout(layout)
132
133        self.output_path = None
134        self.output_path_param, layout = self._add_string_param(
135            "output_path", self.output_path, title="Output Path", tooltip=get_tooltip("training", "output_path")
136        )
137        setting_values.layout().addLayout(layout)
138
139        self.n_epochs = 100
140        self.n_epochs_param, layout = self._add_int_param(
141            "n_epochs", self.n_epochs, title="Number of epochs", min_val=1, max_val=1000,
142            tooltip=get_tooltip("training", "n_epochs"),
143        )
144        setting_values.layout().addLayout(layout)
145
146        settings = widgets._make_collapsible(setting_values, title="Advanced Settings")
147        return settings
148
149    def _get_loaders(self):
150        batch_size = 1
151        num_workers = 1 if str(self.device) == "cpu" else 4
152
153        patch_shape = (self.patch_x, self.patch_y)
154        dataset = default_sam_dataset(
155            raw_paths=str(self.raw_path),
156            raw_key=self.raw_key,
157            label_paths=str(self.label_path),
158            label_key=self.label_key,
159            patch_shape=patch_shape,
160            with_segmentation_decoder=self.with_segmentation_decoder,
161        )
162
163        raw_path_val, label_path_val = self.raw_path_val, self.label_path_val
164        if raw_path_val is None:
165            # Use 10% of the dataset - at least one image - for validation.
166            n_val = min(1, int(0.1 * len(dataset)))
167            train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
168        else:
169            train_dataset = dataset
170            val_dataset = default_sam_dataset(
171                raw_paths=str(raw_path_val),
172                raw_key=self.raw_key,
173                label_paths=str(label_path_val),
174                label_key=self.label_key,
175                patch_shape=patch_shape,
176                with_segmentation_decoder=self.with_segmentation_decoder,
177            )
178
179        train_loader = torch_em.segmentation.get_data_loader(
180            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
181        )
182        val_loader = torch_em.segmentation.get_data_loader(
183            val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
184        )
185        return train_loader, val_loader
186
187    def _get_model_type(self):
188        # Consolidate initial model name, the checkpoint path and the model type according to the configuration.
189        suitable_model_type = CONFIGURATIONS[self.configuration]["model_type"]
190        if self.model_type[:5] == suitable_model_type:
191            self.model_type = suitable_model_type
192        else:
193            warnings.warn(
194                f"You have changed the model type for your chosen configuration '{self.configuration}' "
195                f"from '{suitable_model_type}' to '{self.model_type}'. "
196                "The training may be extremely slow. Please be aware of your custom model choice."
197            )
198
199        assert self.model_type is not None
200
201    # Make sure that raw and label path have been passed.
 202    # If they have not, an error message is generated and returned.
203    # (We could do a more extensive validation here, but for now keep it minimal.)
204    def _validate_inputs(self):
205        missing_raw = self.raw_path is None or not os.path.exists(self.raw_path)
206        missing_label = self.label_path is None or not os.path.exists(self.label_path)
207        if missing_raw or missing_label:
208            msg = ""
209            if missing_raw:
210                msg += "The path to raw data is missing or does not exist. "
211            if missing_label:
212                msg += "The path to label data is missing or does not exist."
213            return widgets._generate_message("error", msg)
214        return False
215
216    def __call__(self, skip_validate=False):
217        self._validate_model_type_and_custom_weights()
218
219        if not skip_validate and self._validate_inputs():
220            return
221
222        # Set up progress bar and signals for using it within a threadworker.
223        pbar, pbar_signals = widgets._create_pbar_for_threadworker()
224
225        self._get_model_type()
226        if self.custom_weights is None:
227            model_registry = util.models()
228            checkpoint_path = model_registry.fetch(self.model_type)
229        else:
230            checkpoint_path = self.custom_weights
231
232        # @thread_worker()
233        def run_training():
234            train_loader, val_loader = self._get_loaders()
235            train_sam_for_configuration(
236                name=self.name,
237                configuration=self.configuration,
238                train_loader=train_loader,
239                val_loader=val_loader,
240                checkpoint_path=checkpoint_path,
241                with_segmentation_decoder=self.with_segmentation_decoder,
242                model_type=self.model_type,
243                device=self.device,
244                n_epochs=self.n_epochs,
245                pbar_signals=pbar_signals,
246            )
247
248            # The best checkpoint after training.
249            export_checkpoint = os.path.join("checkpoints", self.name, "best.pt")
250            assert os.path.exists(export_checkpoint), export_checkpoint
251
252            output_path = _export_helper(
253                "", self.name, self.output_path, self.model_type, self.with_segmentation_decoder, val_loader
254            )
255            pbar_signals.pbar_stop.emit()
256            return output_path
257
258        path = run_training()
259        print(f"Training has finished. The trained model is saved at {path}.")
260        # worker = run_training()
261        # worker.returned.connect(lambda path: print(f"Training has finished. The trained model is saved at {path}."))
262        # worker.start()
263        # return worker

QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())

TrainingWidget(parent=None)
20    def __init__(self, parent=None):
21        super().__init__(parent=parent)
22
23        # Create the UI: the general options.
24        self._create_options()
25
26        # Add the settings (collapsible).
27        self.layout().addWidget(self._create_settings())
28
29        # Add the run button to trigger the embedding computation.
30        self.run_button = QtWidgets.QPushButton("Start Training")
31        self.run_button.clicked.connect(self.__call__)
32        self.layout().addWidget(self.run_button)
run_button