micro_sam.training.simple_sam_trainer
import random

from . import SamTrainer


class SimpleSamTrainer(SamTrainer):
    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    The parent trainer is always configured with `n_sub_iteration=1` and `mask_prob=0`.
    Passing either of these in `kwargs` therefore raises a `TypeError` (duplicate keyword).
    Each step uses either a single positive point or a single box as the prompt per object.

    Args:
        use_points: Whether to use point prompts for interactive segmentation.
        use_box: Whether to use box prompts for interactive segmentation.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """
    def __init__(
        self,
        use_points: bool = True,
        use_box: bool = True,
        **kwargs
    ):
        # Validate before the parent initialization (no wasted setup on bad input), and with an
        # explicit exception rather than `assert`, which is removed when running `python -O`.
        if not (use_points or use_box):
            raise ValueError("Please choose at least one of the prompt-based method.")

        super().__init__(
            n_sub_iteration=1,
            mask_prob=0,
            **kwargs
        )
        self.use_points = use_points
        self.use_box = use_box

        # With both prompt types enabled, one of them is drawn at random on every call.
        self.random_prompt_choice = bool(self.use_points and self.use_box)

    def _choose_one_positive_point(self):
        """Return the prompt configuration for a single positive point per object.

        Returns:
            Tuple `(n_pos, n_neg, get_boxes, multimask_output)`, here `(1, 0, None, True)`.
            Multi-mask output is enabled for this prompt type, presumably because a single
            point is an ambiguous prompt (TODO confirm against `SamTrainer`).
        """
        n_pos, n_neg = 1, 0
        multimask_output = True
        return n_pos, n_neg, None, multimask_output

    def _choose_box(self):
        """Return the prompt configuration for a single box per object.

        Returns:
            Tuple `(n_pos, n_neg, get_boxes, multimask_output)`, here `(0, 0, True, False)`.
        """
        n_pos, n_neg = 0, 0
        multimask_output = False
        get_boxes = True
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Select the prompt configuration for the current iteration.

        Args:
            current_iteration: The current iteration index. It is not used: the choice depends
                only on `use_points` / `use_box`.

        Returns:
            The tuple returned by `_choose_one_positive_point` or `_choose_box`. When both prompt
            types are enabled, one of the two is picked uniformly at random via `random.choice`.
        """
        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
            # Pick the method first, then build only the chosen configuration.
            return random.choice([self._choose_one_positive_point, self._choose_box])()
        if self.use_points:
            return self._choose_one_positive_point()
        return self._choose_box()

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Use the same prompt selection for validation as for training.

        NOTE(review): when both prompt types are enabled the choice comes from the global
        `random` module, so validation prompts (and metrics) vary between runs unless it is seeded.
        """
        return self._get_prompt_and_multimasking_choices(current_iteration)


class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    This class inherits from `SimpleSamTrainer`.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Only box prompts are used (`use_points=False`, `use_box=True`).

    Args:
        kwargs: The keyword arguments of the `SimpleSamTrainer` class, except `use_points` and
            `use_box`, which are fixed by this class.
    """
    def __init__(self, **kwargs):
        super().__init__(
            use_points=False,
            use_box=True,
            **kwargs
        )
class SimpleSamTrainer(SamTrainer):
    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    The parent trainer is always configured with `n_sub_iteration=1` and `mask_prob=0`.
    Passing either of these in `kwargs` therefore raises a `TypeError` (duplicate keyword).
    Each step uses either a single positive point or a single box as the prompt per object.

    Args:
        use_points: Whether to use point prompts for interactive segmentation.
        use_box: Whether to use box prompts for interactive segmentation.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """
    def __init__(
        self,
        use_points: bool = True,
        use_box: bool = True,
        **kwargs
    ):
        # Validate before the parent initialization (no wasted setup on bad input), and with an
        # explicit exception rather than `assert`, which is removed when running `python -O`.
        if not (use_points or use_box):
            raise ValueError("Please choose at least one of the prompt-based method.")

        super().__init__(
            n_sub_iteration=1,
            mask_prob=0,
            **kwargs
        )
        self.use_points = use_points
        self.use_box = use_box

        # With both prompt types enabled, one of them is drawn at random on every call.
        self.random_prompt_choice = bool(self.use_points and self.use_box)

    def _choose_one_positive_point(self):
        """Return the prompt configuration for a single positive point per object.

        Returns:
            Tuple `(n_pos, n_neg, get_boxes, multimask_output)`, here `(1, 0, None, True)`.
            Multi-mask output is enabled for this prompt type, presumably because a single
            point is an ambiguous prompt (TODO confirm against `SamTrainer`).
        """
        n_pos, n_neg = 1, 0
        multimask_output = True
        return n_pos, n_neg, None, multimask_output

    def _choose_box(self):
        """Return the prompt configuration for a single box per object.

        Returns:
            Tuple `(n_pos, n_neg, get_boxes, multimask_output)`, here `(0, 0, True, False)`.
        """
        n_pos, n_neg = 0, 0
        multimask_output = False
        get_boxes = True
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Select the prompt configuration for the current iteration.

        Args:
            current_iteration: The current iteration index. It is not used: the choice depends
                only on `use_points` / `use_box`.

        Returns:
            The tuple returned by `_choose_one_positive_point` or `_choose_box`. When both prompt
            types are enabled, one of the two is picked uniformly at random via `random.choice`.
        """
        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
            # Pick the method first, then build only the chosen configuration.
            return random.choice([self._choose_one_positive_point, self._choose_box])()
        if self.use_points:
            return self._choose_one_positive_point()
        return self._choose_box()

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Use the same prompt selection for validation as for training.

        NOTE(review): when both prompt types are enabled the choice comes from the global
        `random` module, so validation prompts (and metrics) vary between runs unless it is seeded.
        """
        return self._get_prompt_and_multimasking_choices(current_iteration)
Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
This class inherits from `SamTrainer`.
Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
for details on its implementation.
Arguments:
- use_points: Whether to use point prompts for interactive segmentation.
- use_box: Whether to use box prompts for interactive segmentation.
- kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
SimpleSamTrainer(use_points: bool = True, use_box: bool = True, **kwargs)
def __init__(
    self,
    use_points: bool = True,
    use_box: bool = True,
    **kwargs
):
    """Configure the parent trainer for single-prompt training and store the prompt choices.

    The parent is always called with `n_sub_iteration=1` and `mask_prob=0`. Passing either of
    these in `kwargs` therefore raises a `TypeError` (duplicate keyword).

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """
    # Validate before the parent initialization (no wasted setup on bad input), and with an
    # explicit exception rather than `assert`, which is removed when running `python -O`.
    if not (use_points or use_box):
        raise ValueError("Please choose at least one of the prompt-based method.")

    super().__init__(
        n_sub_iteration=1,
        mask_prob=0,
        **kwargs
    )
    self.use_points = use_points
    self.use_box = use_box

    # With both prompt types enabled, one of them is drawn at random on every call.
    self.random_prompt_choice = bool(self.use_points and self.use_box)
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- save_checkpoint
- load_checkpoint
- fit
class MedSAMTrainer(SimpleSamTrainer):
    """Box-prompt-only trainer replicating the MedSAM training setup (https://arxiv.org/abs/2304.12306).

    Derived from `SimpleSamTrainer`; see
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for the implementation it builds on.

    Args:
        kwargs: Keyword arguments forwarded to `SimpleSamTrainer`. `use_points` and `use_box`
            are fixed here; passing either one raises a `TypeError` (duplicate keyword).
    """
    def __init__(self, **kwargs):
        # Point prompts are switched off and box prompts kept on.
        box_only_prompts = {"use_points": False, "use_box": True}
        super().__init__(**box_only_prompts, **kwargs)
Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).
This class inherits from `SimpleSamTrainer`.
Check out
https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
for details on its implementation.
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- save_checkpoint
- load_checkpoint
- fit