micro_sam.training.simple_sam_trainer
import random

from . import SamTrainer


class SimpleSamTrainer(SamTrainer):
    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Every object is prompted with exactly one prompt type: either a single positive point
    (requesting multi-mask output) or a single box (requesting single-mask output).
    The parent trainer is always configured with `n_sub_iteration=1` and `mask_prob=0`
    (presumably: one prompting step per batch and no mask prompts - see `SamTrainer`).

    Args:
        use_points: Whether to use point prompts for interactive segmentation. By default, set to 'True'.
        use_box: Whether to use box prompts for interactive segmentation. By default, set to 'True'.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """

    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
        # Validate first, so no parent-side setup runs for an invalid configuration. Use an
        # explicit exception rather than `assert`, which is stripped under `python -O`.
        if not (use_points or use_box):
            raise ValueError("Please choose at least one of the prompt-based method.")

        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
        self.use_points = use_points
        self.use_box = use_box
        # Sample the prompt type at random only when both prompt types are enabled.
        self.random_prompt_choice = bool(use_points and use_box)

    def _choose_one_positive_point(self):
        """Samples only a single positive point per object.

        Returns:
            Tuple `(n_pos=1, n_neg=0, get_boxes=None, multimask_output=True)`.
        """
        n_pos, n_neg = 1, 0
        multimask_output = True
        return n_pos, n_neg, None, multimask_output

    def _choose_box(self):
        """Samples only a single box per object.

        Returns:
            Tuple `(n_pos=0, n_neg=0, get_boxes=True, multimask_output=False)`.
        """
        n_pos, n_neg = 0, 0
        multimask_output = False
        get_boxes = True
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Returns the prompt configuration for the current training iteration.

        Args:
            current_iteration: Accepted for signature compatibility with the parent hook; not used here.

        Returns:
            The `(n_pos, n_neg, get_boxes, multimask_output)` tuple of the chosen prompt type.
        """
        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
            # Pick the prompt type first and only build the chosen configuration.
            chooser = random.choice([self._choose_one_positive_point, self._choose_box])
            return chooser()
        # Exactly one of "use_points" or "use_box" is True.
        if self.use_points:
            return self._choose_one_positive_point()
        return self._choose_box()

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Validation uses the same prompt selection as training."""
        return self._get_prompt_and_multimasking_choices(current_iteration)


class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    This class inherits from `SimpleSamTrainer`.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Training uses box prompts only (`use_points=False`, `use_box=True`).

    Args:
        kwargs: The keyword arguments of the `SimpleSamTrainer` (and `SamTrainer`) class,
            excluding `use_points` and `use_box`, which are fixed by this class.
    """

    def __init__(self, **kwargs):
        super().__init__(use_points=False, use_box=True, **kwargs)
class SimpleSamTrainer(SamTrainer):
    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Every object is prompted with exactly one prompt type: either a single positive point
    (requesting multi-mask output) or a single box (requesting single-mask output).
    The parent trainer is always configured with `n_sub_iteration=1` and `mask_prob=0`
    (presumably: one prompting step per batch and no mask prompts - see `SamTrainer`).

    Args:
        use_points: Whether to use point prompts for interactive segmentation. By default, set to 'True'.
        use_box: Whether to use box prompts for interactive segmentation. By default, set to 'True'.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """

    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
        # Validate first, so no parent-side setup runs for an invalid configuration. Use an
        # explicit exception rather than `assert`, which is stripped under `python -O`.
        if not (use_points or use_box):
            raise ValueError("Please choose at least one of the prompt-based method.")

        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
        self.use_points = use_points
        self.use_box = use_box
        # Sample the prompt type at random only when both prompt types are enabled.
        self.random_prompt_choice = bool(use_points and use_box)

    def _choose_one_positive_point(self):
        """Samples only a single positive point per object.

        Returns:
            Tuple `(n_pos=1, n_neg=0, get_boxes=None, multimask_output=True)`.
        """
        n_pos, n_neg = 1, 0
        multimask_output = True
        return n_pos, n_neg, None, multimask_output

    def _choose_box(self):
        """Samples only a single box per object.

        Returns:
            Tuple `(n_pos=0, n_neg=0, get_boxes=True, multimask_output=False)`.
        """
        n_pos, n_neg = 0, 0
        multimask_output = False
        get_boxes = True
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Returns the prompt configuration for the current training iteration.

        Args:
            current_iteration: Accepted for signature compatibility with the parent hook; not used here.

        Returns:
            The `(n_pos, n_neg, get_boxes, multimask_output)` tuple of the chosen prompt type.
        """
        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
            # Pick the prompt type first and only build the chosen configuration.
            chooser = random.choice([self._choose_one_positive_point, self._choose_box])
            return chooser()
        # Exactly one of "use_points" or "use_box" is True.
        if self.use_points:
            return self._choose_one_positive_point()
        return self._choose_box()

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Validation uses the same prompt selection as training."""
        return self._get_prompt_and_multimasking_choices(current_iteration)
Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
This class inherits from SamTrainer.
Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
for details on its implementation.
Arguments:
- use_points: Whether to use point prompts for interactive segmentation. Defaults to True.
- use_box: Whether to use box prompts for interactive segmentation. Defaults to True.
- kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class.
SimpleSamTrainer(use_points: bool = True, use_box: bool = True, **kwargs)
def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
    """Configure which prompt types are sampled during training.

    Args:
        use_points: Whether to use point prompts for interactive segmentation. By default, set to 'True'.
        use_box: Whether to use box prompts for interactive segmentation. By default, set to 'True'.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """
    # Validate first, so no parent-side setup runs for an invalid configuration. Use an
    # explicit exception rather than `assert`, which is stripped under `python -O`.
    if not (use_points or use_box):
        raise ValueError("Please choose at least one of the prompt-based method.")

    super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
    self.use_points = use_points
    self.use_box = use_box
    # Sample the prompt type at random only when both prompt types are enabled.
    self.random_prompt_choice = bool(use_points and use_box)
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- is_data_parallel
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit
class MedSAMTrainer(SimpleSamTrainer):
    """Box-prompt-only trainer that replicates the MedSAM training setup (https://arxiv.org/abs/2304.12306).

    Derived from `SimpleSamTrainer`; see
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for the implementation of the shared logic.

    Args:
        kwargs: Forwarded unchanged to `SimpleSamTrainer`. Must not contain `use_points`
            or `use_box`, because this class sets both itself.
    """

    def __init__(self, **kwargs):
        # Fixed prompt configuration: boxes on, points off.
        prompt_config = {"use_box": True, "use_points": False}
        super().__init__(**prompt_config, **kwargs)
Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).
This class inherits from SimpleSamTrainer.
Check out
https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
for details on its implementation.
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- is_data_parallel
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit