micro_sam.training.simple_sam_trainer
import random

from . import SamTrainer


class SimpleSamTrainer(SamTrainer):
    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Each prompt choice uses exactly one kind of prompt per object: either a single positive point
    (with multimask output) or a single box (without multimask output). If both prompt types are
    enabled, one of the two is picked uniformly at random on every choice.

    Args:
        use_points: Whether to use point prompts for interactive segmentation.
        use_box: Whether to use box prompts for interactive segmentation.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
            `n_sub_iteration` and `mask_prob` are fixed by this class and must not be passed.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """

    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
        # Validate before the parent initialization runs, so that an invalid configuration fails
        # without side effects; raise instead of `assert` so the check survives `python -O`.
        if not (use_points or use_box):
            raise ValueError("Please choose at least one of the prompt-based methods.")

        # Fixed settings for this simplified trainer (see `SamTrainer` for their meaning).
        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
        self.use_points = use_points
        self.use_box = use_box
        # Only when both prompt types are enabled is the prompt type drawn at random.
        self.random_prompt_choice = bool(use_points and use_box)

    def _choose_one_positive_point(self):
        """Sample only a single positive point per object.

        Returns:
            The tuple `(n_pos, n_neg, get_boxes, multimask_output)` = `(1, 0, False, True)`.
        """
        n_pos, n_neg = 1, 0
        get_boxes = False  # explicit bool, consistent with `_choose_box`
        multimask_output = True
        return n_pos, n_neg, get_boxes, multimask_output

    def _choose_box(self):
        """Sample only a single box per object.

        Returns:
            The tuple `(n_pos, n_neg, get_boxes, multimask_output)` = `(0, 0, True, False)`.
        """
        n_pos, n_neg = 0, 0
        get_boxes = True
        multimask_output = False
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Choose the prompt type and multimask setting.

        Args:
            current_iteration: The current iteration. Unused here: the prompt type does not follow
                an iteration-based schedule in this trainer.

        Returns:
            The tuple `(n_pos, n_neg, get_boxes, multimask_output)`.
        """
        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
            # Pick the chooser first and only build the tuple that is actually used.
            chooser = random.choice([self._choose_one_positive_point, self._choose_box])
            return chooser()
        if self.use_points:
            return self._choose_one_positive_point()
        return self._choose_box()

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Choose the prompts for validation; uses the same rule as training.

        Note that this is random when both prompt types are enabled.
        """
        return self._get_prompt_and_multimasking_choices(current_iteration)


class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    This class inherits from `SimpleSamTrainer`.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    The trainer uses box prompts only (`use_points=False`, `use_box=True`); these two options
    must therefore not be passed in `kwargs`.
    """

    def __init__(self, **kwargs):
        super().__init__(use_points=False, use_box=True, **kwargs)
class SimpleSamTrainer(SamTrainer):
    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Each prompt choice uses exactly one kind of prompt per object: either a single positive point
    (with multimask output) or a single box (without multimask output). If both prompt types are
    enabled, one of the two is picked uniformly at random on every choice.

    Args:
        use_points: Whether to use point prompts for interactive segmentation.
        use_box: Whether to use box prompts for interactive segmentation.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
            `n_sub_iteration` and `mask_prob` are fixed by this class and must not be passed.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """

    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
        # Validate before the parent initialization runs, so that an invalid configuration fails
        # without side effects; raise instead of `assert` so the check survives `python -O`.
        if not (use_points or use_box):
            raise ValueError("Please choose at least one of the prompt-based methods.")

        # Fixed settings for this simplified trainer (see `SamTrainer` for their meaning).
        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
        self.use_points = use_points
        self.use_box = use_box
        # Only when both prompt types are enabled is the prompt type drawn at random.
        self.random_prompt_choice = bool(use_points and use_box)

    def _choose_one_positive_point(self):
        """Sample only a single positive point per object.

        Returns:
            The tuple `(n_pos, n_neg, get_boxes, multimask_output)` = `(1, 0, False, True)`.
        """
        n_pos, n_neg = 1, 0
        get_boxes = False  # explicit bool, consistent with `_choose_box`
        multimask_output = True
        return n_pos, n_neg, get_boxes, multimask_output

    def _choose_box(self):
        """Sample only a single box per object.

        Returns:
            The tuple `(n_pos, n_neg, get_boxes, multimask_output)` = `(0, 0, True, False)`.
        """
        n_pos, n_neg = 0, 0
        get_boxes = True
        multimask_output = False
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Choose the prompt type and multimask setting.

        Args:
            current_iteration: The current iteration. Unused here: the prompt type does not follow
                an iteration-based schedule in this trainer.

        Returns:
            The tuple `(n_pos, n_neg, get_boxes, multimask_output)`.
        """
        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
            # Pick the chooser first and only build the tuple that is actually used.
            chooser = random.choice([self._choose_one_positive_point, self._choose_box])
            return chooser()
        if self.use_points:
            return self._choose_one_positive_point()
        return self._choose_box()

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Choose the prompts for validation; uses the same rule as training.

        Note that this is random when both prompt types are enabled.
        """
        return self._get_prompt_and_multimasking_choices(current_iteration)
Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
This class inherits from `SamTrainer`.
Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
for details on its implementation.
Arguments:
- use_points: Whether to use point prompts for interactive segmentation.
- use_box: Whether to use box prompts for interactive segmentation.
- kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
SimpleSamTrainer(use_points: bool = True, use_box: bool = True, **kwargs)
def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
    """Set up the simplified SAM trainer.

    Args:
        use_points: Whether to use point prompts for interactive segmentation.
        use_box: Whether to use box prompts for interactive segmentation.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.

    Raises:
        ValueError: If neither `use_points` nor `use_box` is enabled.
    """
    # Validate before the parent initialization runs, so that an invalid configuration fails
    # without side effects; raise instead of `assert` so the check survives `python -O`.
    if not (use_points or use_box):
        raise ValueError("Please choose at least one of the prompt-based methods.")

    # Fixed settings for this simplified trainer (see `SamTrainer` for their meaning).
    super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
    self.use_points = use_points
    self.use_box = use_box
    # Only when both prompt types are enabled is the prompt type drawn at random.
    self.random_prompt_choice = bool(use_points and use_box)
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    This class inherits from `SimpleSamTrainer`.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Prompting is fixed to boxes only (`use_points=False`, `use_box=True`); passing either of these
    two options in `kwargs` raises a `TypeError` (duplicate keyword argument).

    Args:
        kwargs: The remaining keyword arguments of `SimpleSamTrainer`, forwarded unchanged.
    """

    def __init__(self, **kwargs):
        # Box-only prompt configuration, placed ahead of the caller's keyword arguments.
        prompt_settings = dict(use_points=False, use_box=True)
        super().__init__(**prompt_settings, **kwargs)
Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).
This class inherits from `SimpleSamTrainer`.
Check out
https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
for details on its implementation.
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit