micro_sam.training.simple_sam_trainer

 1import random
 2
 3from . import SamTrainer
 4
 5
 6class SimpleSamTrainer(SamTrainer):
 7    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
 8
 9    This class is inherited from `SamTrainer`.
10    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
11    for details on its implementation.
12
13    Args:
14        use_points: Whether to use point prompts for interactive segmentation.
15        use_box: Whether to use box prompts for interactive segmentation.
16        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
17    """
18    def __init__(
19        self,
20        use_points: bool = True,
21        use_box: bool = True,
22        **kwargs
23    ):
24        super().__init__(
25            n_sub_iteration=1,
26            mask_prob=0,
27            **kwargs
28        )
29        self.use_points = use_points
30        self.use_box = use_box
31
32        if self.use_points and self.use_box:
33            self.random_prompt_choice = True
34        else:
35            self.random_prompt_choice = False
36
37        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
38
39    def _choose_one_positive_point(self):
40        """Samples only a single positive point per object
41        """
42        n_pos, n_neg = 1, 0
43        multimask_output = True
44        return n_pos, n_neg, None, multimask_output
45
46    def _choose_box(self):
47        """Samples only a single box per object
48        """
49        n_pos, n_neg = 0, 0
50        multimask_output = False
51        get_boxes = True
52        return n_pos, n_neg, get_boxes, multimask_output
53
54    def _get_prompt_and_multimasking_choices(self, current_iteration):
55        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
56            available_choices = [self._choose_one_positive_point(), self._choose_box()]
57            return random.choice(available_choices)
58        else:  # either of "use_points" or "use_box" are True
59            if self.use_points:
60                return self._choose_one_positive_point()
61            else:
62                return self._choose_box()
63
64    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
65        return self._get_prompt_and_multimasking_choices(current_iteration)
66
67
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    This class inherits from `SimpleSamTrainer` and always trains with box prompts only,
    i.e. it fixes `use_points=False` and `use_box=True`.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Args:
        kwargs: The keyword arguments of the `SimpleSamTrainer` class. `use_points` and `use_box`
            are set by this class and must not be passed.
    """
    def __init__(self, **kwargs):
        # Box prompts only: point prompts are disabled, so no random choice between prompt types happens.
        super().__init__(use_points=False, use_box=True, **kwargs)
class SimpleSamTrainer(micro_sam.training.sam_trainer.SamTrainer):
 7class SimpleSamTrainer(SamTrainer):
 8    """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.
 9
10    This class is inherited from `SamTrainer`.
11    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
12    for details on its implementation.
13
14    Args:
15        use_points: Whether to use point prompts for interactive segmentation.
16        use_box: Whether to use box prompts for interactive segmentation.
17        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
18    """
19    def __init__(
20        self,
21        use_points: bool = True,
22        use_box: bool = True,
23        **kwargs
24    ):
25        super().__init__(
26            n_sub_iteration=1,
27            mask_prob=0,
28            **kwargs
29        )
30        self.use_points = use_points
31        self.use_box = use_box
32
33        if self.use_points and self.use_box:
34            self.random_prompt_choice = True
35        else:
36            self.random_prompt_choice = False
37
38        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
39
40    def _choose_one_positive_point(self):
41        """Samples only a single positive point per object
42        """
43        n_pos, n_neg = 1, 0
44        multimask_output = True
45        return n_pos, n_neg, None, multimask_output
46
47    def _choose_box(self):
48        """Samples only a single box per object
49        """
50        n_pos, n_neg = 0, 0
51        multimask_output = False
52        get_boxes = True
53        return n_pos, n_neg, get_boxes, multimask_output
54
55    def _get_prompt_and_multimasking_choices(self, current_iteration):
56        if self.random_prompt_choice:  # both "use_points" and "use_box" are True
57            available_choices = [self._choose_one_positive_point(), self._choose_box()]
58            return random.choice(available_choices)
59        else:  # either of "use_points" or "use_box" are True
60            if self.use_points:
61                return self._choose_one_positive_point()
62            else:
63                return self._choose_box()
64
65    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
66        return self._get_prompt_and_multimasking_choices(current_iteration)

Trainer class for creating a simple SAM trainer for limited prompt-based segmentation.

This class inherits from SamTrainer. See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py for details on its implementation.

Arguments:
  • use_points: Whether to use point prompts for interactive segmentation.
  • use_box: Whether to use box prompts for interactive segmentation.
  • kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class. n_sub_iteration and mask_prob are fixed by this class (to 1 and 0) and must not be passed.
SimpleSamTrainer(use_points: bool = True, use_box: bool = True, **kwargs)
19    def __init__(
20        self,
21        use_points: bool = True,
22        use_box: bool = True,
23        **kwargs
24    ):
25        super().__init__(
26            n_sub_iteration=1,
27            mask_prob=0,
28            **kwargs
29        )
30        self.use_points = use_points
31        self.use_box = use_box
32
33        if self.use_points and self.use_box:
34            self.random_prompt_choice = True
35        else:
36            self.random_prompt_choice = False
37
38        assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method."
use_points
use_box
Inherited Members
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
from_checkpoint
Serializer
save_checkpoint
load_checkpoint
fit
class MedSAMTrainer(SimpleSamTrainer):
class MedSAMTrainer(SimpleSamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

    This class inherits from `SimpleSamTrainer` and always trains with box prompts only,
    i.e. it fixes `use_points=False` and `use_box=True`.
    Check out
    https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py
    for details on its implementation.

    Args:
        kwargs: The keyword arguments of the `SimpleSamTrainer` class. `use_points` and `use_box`
            are set by this class and must not be passed.
    """
    def __init__(self, **kwargs):
        # Box prompts only: point prompts are disabled, so no random choice between prompt types happens.
        super().__init__(use_points=False, use_box=True, **kwargs)

Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306).

This class inherits from SimpleSamTrainer and trains with box prompts only (use_points=False, use_box=True). See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py for details on its implementation.

MedSAMTrainer(**kwargs)
77    def __init__(self, **kwargs):
78        super().__init__(
79            use_points=False,
80            use_box=True,
81            **kwargs
82        )
Inherited Members
SimpleSamTrainer
use_points
use_box
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
from_checkpoint
Serializer
save_checkpoint
load_checkpoint
fit