micro_sam.training.simple_sam_trainer

 1import random
 2
 3from . import SamTrainer
 4
 5
 6class SimpleSamTrainer(SamTrainer):
 7    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
 8
 9    This class is inherited from `SamTrainer`.
10    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
11    for details on its implementation.
12
13    Args:
14        use_points: Whether to use point prompts for interactive segmentation. By default, set to 'True'.
15        use_box: Whether to use box prompts for interactive segmentation. By default, set to 'True'.
16        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
17    """
18
19    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
20        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
21        self.use_points = use_points
22        self.use_box = use_box
23
24        if self.use_points and self.use_box:
25            self.random_prompt_choice = True
26        else:
27            self.random_prompt_choice = False
28
29        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
30
31    def _choose_one_positive_point(self):
32        """Samples only a single positive point per object
33        """
34        n_pos, n_neg = 1, 0
35        multimask_output = True
36        return n_pos, n_neg, None, multimask_output
37
38    def _choose_box(self):
39        """Samples only a single box per object
40        """
41        n_pos, n_neg = 0, 0
42        multimask_output = False
43        get_boxes = True
44        return n_pos, n_neg, get_boxes, multimask_output
45
46    def _get_prompt_and_multimasking_choices(self, current_iteration):
47        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
48            available_choices = [self._choose_one_positive_point(), self._choose_box()]
49            return random.choice(available_choices)
50        else:  # either of "use_points" or "use_box" are True
51            if self.use_points:
52                return self._choose_one_positive_point()
53            else:
54                return self._choose_box()
55
56    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
57        return self._get_prompt_and_multimasking_choices(current_iteration)
58
59
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class that replicates the training setup of MedSAM (https://arxiv.org/abs/2304.12306).

    It is a `SimpleSamTrainer` restricted to box prompts (`use_points=False`, `use_box=True`).
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Args:
        kwargs: The keyword arguments forwarded to `SimpleSamTrainer`. `use_points` and `use_box`
            are fixed by this class and must not be passed.
    """

    def __init__(self, **kwargs):
        # Box prompts only; every other option is forwarded unchanged.
        super().__init__(use_box=True, use_points=False, **kwargs)
class SimpleSamTrainer(micro_sam.training.sam_trainer.SamTrainer):
 7class SimpleSamTrainer(SamTrainer):
 8    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
 9
10    This class is inherited from `SamTrainer`.
11    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
12    for details on its implementation.
13
14    Args:
15        use_points: Whether to use point prompts for interactive segmentation. By default, set to 'True'.
16        use_box: Whether to use box prompts for interactive segmentation. By default, set to 'True'.
17        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
18    """
19
20    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
21        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
22        self.use_points = use_points
23        self.use_box = use_box
24
25        if self.use_points and self.use_box:
26            self.random_prompt_choice = True
27        else:
28            self.random_prompt_choice = False
29
30        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
31
32    def _choose_one_positive_point(self):
33        """Samples only a single positive point per object
34        """
35        n_pos, n_neg = 1, 0
36        multimask_output = True
37        return n_pos, n_neg, None, multimask_output
38
39    def _choose_box(self):
40        """Samples only a single box per object
41        """
42        n_pos, n_neg = 0, 0
43        multimask_output = False
44        get_boxes = True
45        return n_pos, n_neg, get_boxes, multimask_output
46
47    def _get_prompt_and_multimasking_choices(self, current_iteration):
48        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
49            available_choices = [self._choose_one_positive_point(), self._choose_box()]
50            return random.choice(available_choices)
51        else:  # either of "use_points" or "use_box" are True
52            if self.use_points:
53                return self._choose_one_positive_point()
54            else:
55                return self._choose_box()
56
57    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
58        return self._get_prompt_and_multimasking_choices(current_iteration)

Trainer class for a simple SAM trainer that uses a limited set of prompts (a single positive point or a single box per object).

This class inherits from SamTrainer. See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py for details of its implementation.

Arguments:
  • use_points: Whether to use point prompts for interactive segmentation. Defaults to True.
  • use_box: Whether to use box prompts for interactive segmentation. Defaults to True.
  • kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class.
SimpleSamTrainer(use_points: bool = True, use_box: bool = True, **kwargs)
def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
    """Set up the trainer with point and/or box prompts.

    Args:
        use_points: Whether to use point prompts. Defaults to True.
        use_box: Whether to use box prompts. Defaults to True.
        kwargs: Forwarded to the parent trainer together with `n_sub_iteration=1` and `mask_prob=0`.

    Raises:
        ValueError: If both `use_points` and `use_box` are False.
    """
    # Explicit raise instead of `assert` (which is stripped under `python -O`). Validating
    # first rejects the invalid configuration before the parent trainer is initialized.
    if not (use_points or use_box):
        raise ValueError("Please choose at least one of the prompt-based method.")

    super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)
    self.use_points = use_points
    self.use_box = use_box
    # True only when both prompt types are enabled; the type is then drawn at random.
    self.random_prompt_choice = bool(use_points and use_box)
use_points
use_box
Inherited Members
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
is_data_parallel
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit
class MedSAMTrainer(SimpleSamTrainer):
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class that replicates the training setup of MedSAM (https://arxiv.org/abs/2304.12306).

    It is a `SimpleSamTrainer` restricted to box prompts (`use_points=False`, `use_box=True`).
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Args:
        kwargs: The keyword arguments forwarded to `SimpleSamTrainer`. `use_points` and `use_box`
            are fixed by this class and must not be passed.
    """

    def __init__(self, **kwargs):
        # Box prompts only; every other option is forwarded unchanged.
        super().__init__(use_box=True, use_points=False, **kwargs)

Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

This class inherits from SimpleSamTrainer and only uses box prompts (use_points=False, use_box=True). See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py for details of its implementation.

MedSAMTrainer(**kwargs)
def __init__(self, **kwargs):
    """Set up a box-only `SimpleSamTrainer`, forwarding all other keyword arguments."""
    super().__init__(use_box=True, use_points=False, **kwargs)
Inherited Members
SimpleSamTrainer
use_points
use_box
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
is_data_parallel
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit