micro_sam.training.simple_sam_trainer

 1import random
 2
 3from . import SamTrainer
 4
 5
 6class SimpleSamTrainer(SamTrainer):
 7    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
 8
 9    This class is inherited from `SamTrainer`.
10    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
11    for details on its implementation.
12
13    Args:
14        use_points: Whether to use point prompts for interactive segmentation.
15        use_box: Whether to use box prompts for interactive segmentation.
16        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
17    """
18
19    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
20        super().__init__(
21            n_sub_iteration=1,
22            mask_prob=0,
23            **kwargs
24        )
25        self.use_points = use_points
26        self.use_box = use_box
27
28        if self.use_points and self.use_box:
29            self.random_prompt_choice = True
30        else:
31            self.random_prompt_choice = False
32
33        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
34
35    def _choose_one_positive_point(self):
36        """Samples only a single positive point per object
37        """
38        n_pos, n_neg = 1, 0
39        multimask_output = True
40        return n_pos, n_neg, None, multimask_output
41
42    def _choose_box(self):
43        """Samples only a single box per object
44        """
45        n_pos, n_neg = 0, 0
46        multimask_output = False
47        get_boxes = True
48        return n_pos, n_neg, get_boxes, multimask_output
49
50    def _get_prompt_and_multimasking_choices(self, current_iteration):
51        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
52            available_choices = [self._choose_one_positive_point(), self._choose_box()]
53            return random.choice(available_choices)
54        else:  # either of "use_points" or "use_box" are True
55            if self.use_points:
56                return self._choose_one_positive_point()
57            else:
58                return self._choose_box()
59
60    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
61        return self._get_prompt_and_multimasking_choices(current_iteration)
62
63
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class that replicates the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    It is a `SimpleSamTrainer` with point prompts switched off and box prompts switched on.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on the implementation of the parent class.

    Args:
        kwargs: The keyword arguments forwarded to `SimpleSamTrainer`.
    """

    def __init__(self, **kwargs):
        # The prompt settings are fixed; everything else is passed through unchanged.
        prompt_settings = {"use_points": False, "use_box": True}
        super().__init__(**prompt_settings, **kwargs)
class SimpleSamTrainer(micro_sam.training.sam_trainer.SamTrainer):
 7class SimpleSamTrainer(SamTrainer):
 8    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
 9
10    This class is inherited from `SamTrainer`.
11    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
12    for details on its implementation.
13
14    Args:
15        use_points: Whether to use point prompts for interactive segmentation.
16        use_box: Whether to use box prompts for interactive segmentation.
17        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
18    """
19
20    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
21        super().__init__(
22            n_sub_iteration=1,
23            mask_prob=0,
24            **kwargs
25        )
26        self.use_points = use_points
27        self.use_box = use_box
28
29        if self.use_points and self.use_box:
30            self.random_prompt_choice = True
31        else:
32            self.random_prompt_choice = False
33
34        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
35
36    def _choose_one_positive_point(self):
37        """Samples only a single positive point per object
38        """
39        n_pos, n_neg = 1, 0
40        multimask_output = True
41        return n_pos, n_neg, None, multimask_output
42
43    def _choose_box(self):
44        """Samples only a single box per object
45        """
46        n_pos, n_neg = 0, 0
47        multimask_output = False
48        get_boxes = True
49        return n_pos, n_neg, get_boxes, multimask_output
50
51    def _get_prompt_and_multimasking_choices(self, current_iteration):
52        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
53            available_choices = [self._choose_one_positive_point(), self._choose_box()]
54            return random.choice(available_choices)
55        else:  # either of "use_points" or "use_box" are True
56            if self.use_points:
57                return self._choose_one_positive_point()
58            else:
59                return self._choose_box()
60
61    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
62        return self._get_prompt_and_multimasking_choices(current_iteration)

Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

This class inherits from SamTrainer. See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py for details on the implementation of the parent class.

Arguments:
  • use_points: Whether to use point prompts for interactive segmentation.
  • use_box: Whether to use box prompts for interactive segmentation.
  • kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class.
SimpleSamTrainer(use_points: bool = True, use_box: bool = True, **kwargs)
20    def __init__(self, use_points: bool = True, use_box: bool = True, **kwargs):
21        super().__init__(
22            n_sub_iteration=1,
23            mask_prob=0,
24            **kwargs
25        )
26        self.use_points = use_points
27        self.use_box = use_box
28
29        if self.use_points and self.use_box:
30            self.random_prompt_choice = True
31        else:
32            self.random_prompt_choice = False
33
34        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
use_points
use_box
Inherited Members
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit
class MedSAMTrainer(SimpleSamTrainer):
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class that replicates the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    It is a `SimpleSamTrainer` with point prompts switched off and box prompts switched on.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on the implementation of the parent class.

    Args:
        kwargs: The keyword arguments forwarded to `SimpleSamTrainer`.
    """

    def __init__(self, **kwargs):
        # The prompt settings are fixed; everything else is passed through unchanged.
        prompt_settings = {"use_points": False, "use_box": True}
        super().__init__(**prompt_settings, **kwargs)

Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

This class inherits from SimpleSamTrainer. See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py for details on the implementation of the parent class.

MedSAMTrainer(**kwargs)
74    def __init__(self, **kwargs):
75        super().__init__(use_points=False, use_box=True, **kwargs)
Inherited Members
SimpleSamTrainer
use_points
use_box
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit