micro_sam.sam_annotator.annotator_tracking
from typing import Optional, Tuple, Union

import napari
import numpy as np

import torch

from magicgui.widgets import ComboBox, Container

from .. import util
from . import util as vutil
from . import _widgets as widgets
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from ._annotator import _AnnotatorBase


# Cyan (track) and Magenta (division)
STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
"""@private"""


# This solution is a bit hacky, so I won't move it to _widgets.py yet.
def create_tracking_menu(points_layer, box_layer, states, track_ids):
    """@private"""
    # Builds a container with two dropdowns ("track_state" and "track_id") and keeps them in sync with the
    # prompt layers:
    # - changes of the layers' current properties update the dropdowns (the state only for the point layer),
    # - changes of the dropdowns are written back to the layers' current properties,
    # - track id changes coming from either side are written to `AnnotatorState().current_track_id`.
    # Returns the magicgui Container holding both dropdowns.
    state = AnnotatorState()

    state_menu = ComboBox(label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state"))
    track_id_menu = ComboBox(
        label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id")
    )
    tracking_widget = Container(widgets=[state_menu, track_id_menu])

    # Layer -> menu synchronization.
    def update_state(event):
        new_state = str(points_layer.current_properties["state"][0])
        if new_state != state_menu.value:
            state_menu.value = new_state

    def update_track_id(event):
        new_id = str(points_layer.current_properties["track_id"][0])
        if new_id != track_id_menu.value:
            track_id_menu.value = new_id
            state.current_track_id = int(new_id)

    # def update_state_boxes(event):
    #     new_state = str(box_layer.current_properties["state"][0])
    #     if new_state != state_menu.value:
    #         state_menu.value = new_state

    def update_track_id_boxes(event):
        new_id = str(box_layer.current_properties["track_id"][0])
        if new_id != track_id_menu.value:
            track_id_menu.value = new_id
            state.current_track_id = int(new_id)

    points_layer.events.current_properties.connect(update_state)
    points_layer.events.current_properties.connect(update_track_id)
    # box_layer.events.current_properties.connect(update_state_boxes)
    box_layer.events.current_properties.connect(update_track_id_boxes)

    # Menu -> layer synchronization.
    def state_changed(new_state):
        current_properties = points_layer.current_properties
        current_properties["state"] = np.array([new_state])
        points_layer.current_properties = current_properties
        # The point face color is driven by the "state" property, so the colors need a refresh.
        points_layer.refresh_colors()

    def track_id_changed(new_track_id):
        current_properties = points_layer.current_properties
        current_properties["track_id"] = np.array([new_track_id])
        # Note: this fails with a key error after committing a lineage with multiple tracks.
        # I think this does not cause any further errors, so we just skip this.
        try:
            points_layer.current_properties = current_properties
        except KeyError:
            pass
        state.current_track_id = int(new_track_id)

    # def state_changed_boxes(new_state):
    #     current_properties = box_layer.current_properties
    #     current_properties["state"] = np.array([new_state])
    #     box_layer.current_properties = current_properties
    #     box_layer.refresh_colors()

    def track_id_changed_boxes(new_track_id):
        current_properties = box_layer.current_properties
        current_properties["track_id"] = np.array([new_track_id])
        box_layer.current_properties = current_properties
        state.current_track_id = int(new_track_id)

    state_menu.changed.connect(state_changed)
    track_id_menu.changed.connect(track_id_changed)
    # state_menu.changed.connect(state_changed_boxes)
    track_id_menu.changed.connect(track_id_changed_boxes)

    state_menu.set_choice("track")
    return tracking_widget


class AnnotatorTracking(_AnnotatorBase):
    """Annotator for tracking objects in timeseries data.

    Compared to the base annotator it uses tracking-specific prompt layers (points carry a track state
    and a track id, boxes carry a track id) and tracking-specific widgets (track menu, frame and nd
    segmentation, automatic tracking, commit and clear).
    """

    # The tracking annotator needs different settings for the prompt layers
    # to support the additional tracking state.
    # That's why we over-ride this function.
    def _create_layers(self):
        self._point_labels = ["positive", "negative"]
        self._track_state_labels = ["track", "division"]

        self._point_prompt_layer = self._viewer.add_points(
            name="point_prompts",
            property_choices={
                "label": self._point_labels,
                "state": self._track_state_labels,
                "track_id": ["1"],  # we use string to avoid pandas warning
            },
            border_color="label",
            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
            symbol="o",
            face_color="state",
            face_color_cycle=STATE_COLOR_CYCLE,
            border_width=0.4,
            size=12,
            ndim=self._ndim,
        )
        self._point_prompt_layer.border_color_mode = "cycle"
        self._point_prompt_layer.face_color_mode = "cycle"

        # Using the box layer to set divisions currently doesn't work.
        # That's why some of the code below is commented out.
        self._box_prompt_layer = self._viewer.add_shapes(
            shape_type="rectangle",
            edge_width=4,
            ndim=self._ndim,
            face_color="transparent",
            name="prompts",
            edge_color="green",
            property_choices={"track_id": ["1"]},
            # property_choices={"track_id": ["1"], "state": self._track_state_labels},
            # edge_color_cycle=STATE_COLOR_CYCLE,
        )
        # self._box_prompt_layer.edge_color_mode = "cycle"

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        # Each layer gets its own zero-initialized buffer: napari's Labels layer keeps a reference to the array
        # it is given rather than a copy, so one shared array would make edits to one layer leak into the others.
        for layer_name in ("current_object", "auto_segmentation", "committed_objects"):
            self._viewer.add_labels(data=np.zeros(self._shape, dtype="uint32"), name=layer_name)
        # Randomize colors so it is easy to see when object committed.
        self._viewer.layers["committed_objects"].new_colormap()

    def _get_widgets(self):
        state = AnnotatorState()
        # Create the tracking state menu.
        self._tracking_widget = create_tracking_menu(
            self._point_prompt_layer, self._box_prompt_layer,
            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
        )
        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        return {
            "tracking": self._tracking_widget,
            "segment": widgets.segment_frame(),
            "segment_nd": segment_nd,
            "autosegment": autotrack,
            "commit": widgets.commit_track(),
            "clear": widgets.clear_track(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        # Initialize the state for tracking.
        self._init_track_state()
        # Set before the base initializer runs; presumably the base class builds the widgets there
        # (via `_get_widgets`, which reads this flag).
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)
        # Go to t=0.
        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])

    def _init_track_state(self):
        # Reset the tracking state: start with track id 1 and an empty lineage for it.
        state = AnnotatorState()
        state.current_track_id = 1
        state.lineage = {1: []}
        state.committed_lineages = []

    def _update_image(self):
        super()._update_image()
        self._init_track_state()
        state = AnnotatorState()
        # Load the state for automatic segmentation that matches the available model
        # (instance segmentation state with a decoder, AMG state otherwise).
        if self._with_decoder:
            state.amg_state = vutil._load_is_state(state.embedding_path)
        else:
            state.amg_state = vutil._load_amg_state(state.embedding_path)


def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    # tracking_result: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data of the timeseries, with time as the first axis.
            If it has 4 dimensions, the last axis is treated as the channel axis.
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """

    # Initialize the predictor state.
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
        precompute_amg_state=precompute_amg_state,
    )
    # For 4d input the trailing channel axis is not part of the image shape.
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()


def main():
    """@private"""
    parser = vutil._initialize_parser(
        description="Run interactive segmentation for an image volume.",
        with_segmentation_result=False,
        with_instance_segmentation=False,
    )

    # Tracking result is not yet supported, we need to also deserialize the lineage.
    # parser.add_argument(
    #     "-t", "--tracking_result",
    #     help="Optional filepath to a precomputed tracking result. If passed this will be used to initialize the "
    #     "'committed_tracks' layer. This can be useful if you want to correct an existing tracking result or if you "
    #     "have saved intermediate results from the annotator and want to continue. "
    #     "Supports the same file formats as 'input'."
    # )
    # parser.add_argument(
    #     "-tk", "--tracking_key",
    #     help="The key for opening the tracking result. Same rules as for 'key' apply."
    # )

    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    annotator_tracking(
        image, embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo,
        checkpoint_path=args.checkpoint, device=args.device,
    )
class AnnotatorTracking(_AnnotatorBase):
    """Annotator for tracking objects in timeseries data.

    Compared to the base annotator it uses tracking-specific prompt layers (points carry a track state
    and a track id, boxes carry a track id) and tracking-specific widgets (track menu, frame and nd
    segmentation, automatic tracking, commit and clear).
    """

    # The tracking annotator needs different settings for the prompt layers
    # to support the additional tracking state.
    # That's why we over-ride this function.
    def _create_layers(self):
        self._point_labels = ["positive", "negative"]
        self._track_state_labels = ["track", "division"]

        self._point_prompt_layer = self._viewer.add_points(
            name="point_prompts",
            property_choices={
                "label": self._point_labels,
                "state": self._track_state_labels,
                "track_id": ["1"],  # we use string to avoid pandas warning
            },
            border_color="label",
            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
            symbol="o",
            face_color="state",
            face_color_cycle=STATE_COLOR_CYCLE,
            border_width=0.4,
            size=12,
            ndim=self._ndim,
        )
        self._point_prompt_layer.border_color_mode = "cycle"
        self._point_prompt_layer.face_color_mode = "cycle"

        # Using the box layer to set divisions currently doesn't work.
        # That's why some of the code below is commented out.
        self._box_prompt_layer = self._viewer.add_shapes(
            shape_type="rectangle",
            edge_width=4,
            ndim=self._ndim,
            face_color="transparent",
            name="prompts",
            edge_color="green",
            property_choices={"track_id": ["1"]},
            # property_choices={"track_id": ["1"], "state": self._track_state_labels},
            # edge_color_cycle=STATE_COLOR_CYCLE,
        )
        # self._box_prompt_layer.edge_color_mode = "cycle"

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        # Each layer gets its own zero-initialized buffer: napari's Labels layer keeps a reference to the array
        # it is given rather than a copy, so one shared array would make edits to one layer leak into the others.
        for layer_name in ("current_object", "auto_segmentation", "committed_objects"):
            self._viewer.add_labels(data=np.zeros(self._shape, dtype="uint32"), name=layer_name)
        # Randomize colors so it is easy to see when object committed.
        self._viewer.layers["committed_objects"].new_colormap()

    def _get_widgets(self):
        state = AnnotatorState()
        # Create the tracking state menu.
        self._tracking_widget = create_tracking_menu(
            self._point_prompt_layer, self._box_prompt_layer,
            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
        )
        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        return {
            "tracking": self._tracking_widget,
            "segment": widgets.segment_frame(),
            "segment_nd": segment_nd,
            "autosegment": autotrack,
            "commit": widgets.commit_track(),
            "clear": widgets.clear_track(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        # Initialize the state for tracking.
        self._init_track_state()
        # Set before the base initializer runs; presumably the base class builds the widgets there
        # (via `_get_widgets`, which reads this flag).
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)
        # Go to t=0.
        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])

    def _init_track_state(self):
        # Reset the tracking state: start with track id 1 and an empty lineage for it.
        state = AnnotatorState()
        state.current_track_id = 1
        state.lineage = {1: []}
        state.committed_lineages = []

    def _update_image(self):
        super()._update_image()
        self._init_track_state()
        state = AnnotatorState()
        # Load the state for automatic segmentation that matches the available model
        # (instance segmentation state with a decoder, AMG state otherwise).
        if self._with_decoder:
            state.amg_state = vutil._load_is_state(state.embedding_path)
        else:
            state.amg_state = vutil._load_amg_state(state.embedding_path)
Base class for micro_sam annotation plugins.
Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.
AnnotatorTracking(viewer: napari.viewer.Viewer)
def __init__(self, viewer: "napari.viewer.Viewer") -> None:
    """Create the tracking annotator in the given viewer.

    Args:
        viewer: The napari viewer.
    """
    # Reset the tracking state (current track id, lineage, committed lineages) first.
    self._init_track_state()
    # The decoder flag is stored before the base initializer runs; presumably the base class
    # creates the widgets there (via `_get_widgets`), which reads this flag.
    decoder = AnnotatorState().decoder
    self._with_decoder = decoder is not None
    super().__init__(viewer=viewer, ndim=3)
    # Go to t=0, with the remaining axes at their midpoints.
    # NOTE(review): the prefix has three zeros although `self._shape[1:]` already covers all non-time axes;
    # confirm the intended length of `current_step`.
    midpoints = tuple(extent // 2 for extent in self._shape[1:])
    self._viewer.dims.current_step = (0, 0, 0) + midpoints
Create the annotator GUI.
Arguments:
- viewer: The napari viewer.
- ndim: The number of spatial dimensions of the image data (2 or 3).
Inherited Members
- PyQt5.QtWidgets.QScrollArea
- alignment
- ensureVisible
- ensureWidgetVisible
- event
- eventFilter
- focusNextPrevChild
- resizeEvent
- scrollContentsBy
- setAlignment
- setWidget
- setWidgetResizable
- sizeHint
- takeWidget
- viewportSizeHint
- widget
- widgetResizable
- PyQt5.QtWidgets.QAbstractScrollArea
- SizeAdjustPolicy
- addScrollBarWidget
- contextMenuEvent
- cornerWidget
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- horizontalScrollBar
- horizontalScrollBarPolicy
- keyPressEvent
- maximumViewportSize
- minimumSizeHint
- mouseDoubleClickEvent
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- paintEvent
- scrollBarWidgets
- setCornerWidget
- setHorizontalScrollBar
- setHorizontalScrollBarPolicy
- setSizeAdjustPolicy
- setVerticalScrollBar
- setVerticalScrollBarPolicy
- setViewport
- setViewportMargins
- setupViewport
- sizeAdjustPolicy
- verticalScrollBar
- verticalScrollBarPolicy
- viewport
- viewportEvent
- viewportMargins
- wheelEvent
- AdjustIgnored
- AdjustToContents
- AdjustToContentsOnFirstShow
- PyQt5.QtWidgets.QFrame
- Shadow
- Shape
- StyleMask
- changeEvent
- drawFrame
- frameRect
- frameShadow
- frameShape
- frameStyle
- frameWidth
- initStyleOption
- lineWidth
- midLineWidth
- setFrameRect
- setFrameShadow
- setFrameShape
- setFrameStyle
- setLineWidth
- setMidLineWidth
- Box
- HLine
- NoFrame
- Panel
- Plain
- Raised
- Shadow_Mask
- Shape_Mask
- StyledPanel
- Sunken
- VLine
- WinPanel
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- effectiveWinId
- ensurePolished
- enterEvent
- find
- focusInEvent
- focusNextChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumWidth
- mouseGrabber
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM
def
annotator_tracking( image: numpy.ndarray, embedding_path: Optional[str] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None) -> Optional[napari.viewer.Viewer]:
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    # tracking_result: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data of the timeseries, with time as the first axis.
            If it has 4 dimensions, the last axis is treated as the channel axis.
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """

    # Initialize the predictor state.
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
        precompute_amg_state=precompute_amg_state,
    )
    # For 4d input the trailing channel axis is not part of the image shape.
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
Start the tracking annotation tool for a given timeseries.
Arguments:
- image: The image data.
- embedding_path: Filepath for saving the precomputed embeddings.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction.
If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- device: The computational device to use for the SAM model.
Returns:
The napari viewer, only returned if `return_viewer=True`.