micro_sam.sam_annotator.util
import os
import pickle
import warnings
import argparse
from glob import glob
from pathlib import Path
from typing import List, Optional, Tuple

import h5py
import napari
import numpy as np
from skimage import draw
from scipy.ndimage import shift

from .. import prompt_based_segmentation, util
from .. import _model_settings as model_settings
from ..multi_dimensional_segmentation import _validate_projection

# Green and Red
LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
"""@private"""


#
# Misc helper functions
#


def toggle_label(prompts):
    """@private

    Flip the label used for new prompts between "positive" and "negative" and refresh the layer.
    """
    # get the currently selected label
    current_properties = prompts.current_properties
    current_label = current_properties["label"][0]
    new_label = "negative" if current_label == "positive" else "positive"
    current_properties["label"] = np.array([new_label])
    prompts.current_properties = current_properties
    prompts.refresh()
    prompts.refresh_colors()


def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):
    """Create the argument parser shared by the annotator command line entry points.

    Args:
        description: The description shown in the command line help.
        with_segmentation_result: Whether to add the '--segmentation_result' and '--segmentation_key' arguments.
        with_instance_segmentation: Whether to add the arguments for automatic instance segmentation.

    Returns:
        The argument parser.
    """
    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "-i", "--input", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "It is recommended to pass this argument and store the embeddings if you want to open the annotator "
        "multiple times for this image. Otherwise the embeddings will be recomputed every time."
    )

    if with_segmentation_result:
        parser.add_argument(
            "-s", "--segmentation_result",
            help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
            "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
            "have saved intermediate results from the annotator and want to continue with your annotations. "
            "Supports the same file formats as 'input'."
        )
        parser.add_argument(
            "-sk", "--segmentation_key",
            help="The key for opening the segmentation data. Same rules as for 'key' apply."
        )

    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--decoder_path", default=None,
        help="Optional checkpoint path to decoder-only weights to enable decoder-based instance segmentation."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        # The trailing space fixes the missing separator between the two concatenated sentences.
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )

    if with_instance_segmentation:
        parser.add_argument(
            "--precompute_amg_state", action="store_true",
            help="Whether to precompute the state for automatic instance segmentation. "
            "This will lead to a longer start-up time, but the automatic instance segmentation can "
            "be run directly once the tool has started."
        )
        # 'store_false': args.prefer_decoder is True by default and passing the flag sets it to False.
        # The help text describes this behavior (the previous text suggested the opposite).
        parser.add_argument(
            "--prefer_decoder", action="store_false",
            help="Pass this flag to disable decoder-based instance segmentation. By default it is used if the model "
            "being used has an additional decoder for that purpose."
        )

    return parser


def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
    """@private

    Remove all point and shape prompts and, optionally, reset the 'current_object' layer.
    """
    viewer.layers["point_prompts"].data = []
    viewer.layers["point_prompts"].refresh()
    if "prompts" in viewer.layers:
        # Select all prompts and then remove them.
        # This is how it worked before napari 0.5.
        # viewer.layers["prompts"].data = []
        viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
        viewer.layers["prompts"].remove_selected()
        viewer.layers["prompts"].refresh()
    if not clear_segmentations:
        return
    viewer.layers["current_object"].data = np.zeros(viewer.layers["current_object"].data.shape, dtype="uint32")
    viewer.layers["current_object"].refresh()


def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None:
    """@private

    Remove the point and shape prompts in slice i and, optionally, reset slice i of 'current_object'.
    """
    # The slice coordinate is rounded, consistent with point_layer_to_prompts.
    point_prompts = viewer.layers["point_prompts"].data
    point_prompts = point_prompts[np.round(point_prompts[:, 0]) != i]
    viewer.layers["point_prompts"].data = point_prompts
    viewer.layers["point_prompts"].refresh()
    if "prompts" in viewer.layers:
        prompts = viewer.layers["prompts"].data
        prompts = [prompt for prompt in prompts if not (np.round(prompt[:, 0]) == i).all()]
        viewer.layers["prompts"].data = prompts
        viewer.layers["prompts"].refresh()
    if not clear_segmentations:
        return
    viewer.layers["current_object"].data[i] = 0
    viewer.layers["current_object"].refresh()


#
# Helper functions to extract prompts from napari layers.
#


def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
        None is returned instead if a single negative point (stop annotation) was selected
        and with_stop_annotation is True.
    """

    points = layer.data
    labels = layer.properties["label"]
    assert len(points) == len(labels)

    if i is None:
        assert points.shape[1] == 2, f"{points.shape}"
        this_points, this_labels = points, labels
    else:
        assert points.shape[1] == 3, f"{points.shape}"
        mask = np.round(points[:, 0]) == i
        this_points = points[mask][:, 1:]
        this_labels = labels[mask]
    assert len(this_points) == len(this_labels)

    if track_id is not None:
        assert i is not None
        track_ids = np.array(list(map(int, layer.properties["track_id"])))[mask]
        track_id_mask = track_ids == track_id
        this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
    assert len(this_points) == len(this_labels)

    this_labels = np.array([1 if label == "positive" else 0 for label in this_labels])
    # a single point with a negative label is interpreted as 'stop' signal
    # in this case we return None
    if with_stop_annotation and (len(this_points) == 1 and this_labels[0] == 0):
        return None

    return this_points, this_labels


def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
    for 'ellipse' and 'polygon' shapes. Other shape types are ignored with a warning.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts, each as [row_min, col_min, row_max, col_max].
        The mask prompts (None for 'rectangle' shapes).
    """

    def _to_prompts(shape_data, shape_types):
        boxes, masks = [], []

        for data, type_ in zip(shape_data, shape_types):

            if type_ == "rectangle":
                boxes.append(data)
                masks.append(None)

            elif type_ == "ellipse":
                boxes.append(data)
                center = np.mean(data, axis=0)
                # NOTE(review): relies on the vertex order napari uses for an ellipse's bounding box.
                radius_r = ((data[2] - data[1]) / 2)[0]
                radius_c = ((data[1] - data[0]) / 2)[1]
                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            elif type_ == "polygon":
                boxes.append(data)
                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            else:
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")

        # map to correct box format
        boxes = [
            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
        ]
        return boxes, masks

    shape_data, shape_types = layer.data, layer.shape_type
    assert len(shape_data) == len(shape_types)
    if len(shape_data) == 0:
        return [], []

    if i is not None:

        def _in_slice(data):
            # The slice coordinate is rounded, consistent with point_layer_to_prompts.
            return (np.round(data[:, 0]) == i).all()

        if track_id is None:
            prompt_selection = [j for j, data in enumerate(shape_data) if _in_slice(data)]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            prompt_selection = [
                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
                if _in_slice(data) and this_track_id == track_id
            ]

        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
        shape_types = [shape_types[j] for j in prompt_selection]

    boxes, masks = _to_prompts(shape_data, shape_types)
    return boxes, masks


def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    state = prompt_layer.properties["state"]

    points = prompt_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # The frame coordinate is rounded, consistent with point_layer_to_prompts.
    mask = np.round(points[:, 0]) == i
    this_state = state[mask]
    assert len(points[mask]) == len(this_state)

    # we set the state to 'division' if at least one point in this frame has a division label
    return "division" if any(st == "division" for st in this_state) else "track"


def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer.
        box_layer: The napari box layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    state = point_layer.properties["state"]

    points = point_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # The frame coordinate is rounded, consistent with point_layer_to_prompts.
    mask = np.round(points[:, 0]) == i
    this_state = state[mask].tolist() if mask.sum() > 0 else []

    box_states = box_layer.properties["state"]
    this_state.extend(
        box_state for box, box_state in zip(box_layer.data, box_states)
        if (np.round(box[:, 0]) == i).all()
    )

    # we set the state to 'division' if at least one prompt in this frame has a division label
    return "division" if any(st == "division" for st in this_state) else "track"


#
# Helper functions to run (multi-dimensional) segmentation on napari layers.
#


def segment_slices_with_prompts(
    predictor, point_prompts, box_prompts, image_embeddings, shape, track_id=None, update_progress=None,
):
    """@private

    Segment every slice (or frame) of a 3d volume that contains point or box prompts.

    Returns:
        The segmentation, the annotated slice indices, and two flags that indicate whether
        the lowest / highest annotated slice is a stop annotation (a single negative point).
    """
    assert len(shape) == 3
    image_shape = shape[1:]
    seg = np.zeros(shape, dtype="uint32")

    z_values = np.round(point_prompts.data[:, 0])
    # The first vertex gives the z-coordinate of a box; it is rounded like the point z-values
    # (previously it was truncated by the astype("int") below).
    if len(box_prompts.data) > 0:
        z_values_boxes = np.round(np.concatenate([box[:1, 0] for box in box_prompts.data]))
    else:
        z_values_boxes = np.zeros(0, dtype="int")

    if track_id is not None:
        track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
        assert len(track_ids_points) == len(z_values)
        z_values = z_values[track_ids_points == track_id]

        if len(z_values_boxes) > 0:
            track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
            assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
            z_values_boxes = z_values_boxes[track_ids_boxes == track_id]

    slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
    stop_lower, stop_upper = False, False

    if update_progress is None:
        def update_progress(*args):
            pass

    for i in slices:
        points_i = point_layer_to_prompts(point_prompts, i, track_id)

        # do we end the segmentation at the outer slices?
        if points_i is None:

            if i == slices[0]:  # The bottom slice is a stop slice.
                stop_lower = True
                seg[i] = 0
            elif i == slices[-1]:  # The top slice is a stop slice.
                stop_upper = True
                seg[i] = 0
            else:  # We have a stop annotation somewhere in the middle. Ignore this.
                # Remove this slice from the annotated slices, so that it is segmented via
                # projection in the next step.
                slices = np.setdiff1d(slices, i)
                print(f"You have provided a stop annotation (single red point) in slice {i},")
                print("but you have annotated slices above or below it. This stop annotation will")
                print(f"be ignored and the slice {i} will be segmented normally.")

            update_progress(1)
            continue

        boxes, masks = shape_layer_to_prompts(box_prompts, image_shape, i=i, track_id=track_id)
        points, labels = points_i

        seg_i = prompt_segmentation(
            predictor, points, labels, boxes, masks, image_shape, multiple_box_prompts=False,
            image_embeddings=image_embeddings, i=i
        )
        if seg_i is None:
            print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
            print("This will lead to a wrong segmentation across slices or frames.")
            print(f"Please correct the prompts in {i} and rerun the segmentation.")
            # Report progress for the skipped slice too, so the progress count matches the number of slices.
            update_progress(1)
            continue

        seg[i] = seg_i
        update_progress(1)

    return seg, slices, stop_lower, stop_upper


# For advanced batching: match prompts to already segmented objects and continue segmentation.
def _match_prompts(previous_segmentation, points, boxes, seg_ids):
    """Placeholder for matching prompts to existing objects; currently returns an empty mapping."""
    # Create a mapping between ids and prompts.
    batched_prompts = {}
    # seg_boundaries = find_boundaries(previous_segmentation, mode="inner")
    # indices = distance_transform_edt(seg_boundaries, return_distance=False, return_index=True)
    return batched_prompts


def _batched_interactive_segmentation(predictor, points, labels, boxes, image_embeddings, i, previous_segmentation):
    """Segment each positive point and each box as a separate object.

    All negative points are shared: they are added to the prompts of every object.
    The object ids are assigned in order: positive points first, then boxes, starting at 1.
    """
    prev_seg = previous_segmentation if i is None else previous_segmentation[i]
    seg = np.zeros(prev_seg.shape, dtype="uint32")

    # seg_ids = np.unique(previous_segmentation)
    # assert seg_ids[0] == 0

    batched_points, batched_labels = [], []
    negative_points, negative_labels = [], []
    for j in range(len(points)):
        if labels[j] == 1:  # positive point
            batched_points.append(points[j:j+1])
            batched_labels.append(labels[j:j+1])
        else:  # negative points
            negative_points.append(points[j:j+1])
            negative_labels.append(labels[j:j+1])

    # Merge the negative points into single arrays once. Previously a box prompt without a positive
    # point received the raw Python lists of (1, 2) arrays instead of a point array.
    if negative_points:
        all_negative_points = np.concatenate(negative_points)
        all_negative_labels = np.concatenate(negative_labels)
    else:
        all_negative_points, all_negative_labels = None, None

    batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
    batched_prompts.extend([(box, None, None) for box in boxes])
    batched_prompts = {seg_id: prompt for seg_id, prompt in enumerate(batched_prompts, 1)}

    # For advanced batching: match prompts to already segmented objects and continue segmentation.
    # (This is left here as a reference for how this can be implemented.
    # I have not decided yet if this is actually a good idea or not.)
    # # If we have no objects: this is the first call for a batched segmentation.
    # # We treat each positive point or box as a separate object.
    # if len(seg_ids) == 1:
    #     # Create a list of all prompts.
    #     batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
    #     batched_prompts.extend([(box, None, None) for box in boxes])
    #     batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}

    # # Otherwise we match the prompts to existing objects.
    # else:
    #     batched_prompts = _match_prompts(prev_seg, batched_points, boxes, seg_ids)

    for seg_id, prompt in batched_prompts.items():
        box, point, label = prompt
        if all_negative_points is not None:
            if point is None:
                point, label = all_negative_points, all_negative_labels
            else:
                point = np.concatenate([point, all_negative_points])
                label = np.concatenate([label, all_negative_labels])

        if (box is not None) and (point is not None):
            prediction = prompt_based_segmentation.segment_from_box_and_points(
                predictor, box, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()
        elif (box is not None) and (point is None):
            prediction = prompt_based_segmentation.segment_from_box(
                predictor, box, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            prediction = prompt_based_segmentation.segment_from_points(
                predictor, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()

        seg[prediction] = seg_id

    return seg


def prompt_segmentation(
    predictor, points, labels, boxes, masks, shape, multiple_box_prompts,
    image_embeddings=None, i=None, box_extension=0, batched=None, previous_segmentation=None,
):
    """@private

    Segment from the given point / box / mask prompts.

    Returns:
        The segmentation, or None if no prompts were given or the prompt combination is not supported.
    """
    assert len(points) == len(labels)
    have_points = len(points) > 0
    have_boxes = len(boxes) > 0

    # No prompts were given, return None.
    if not have_points and not have_boxes:
        return None

    # Batched interactive segmentation.
    elif batched:
        assert previous_segmentation is not None
        seg = _batched_interactive_segmentation(
            predictor, points, labels, boxes, image_embeddings, i, previous_segmentation
        )

    # Box and point prompts were given.
    elif have_points and have_boxes:
        if len(boxes) > 1:
            print("You have provided point prompts and more than one box prompt.")
            print("This setting is currently not supported.")
            print("When providing both points and boxes you can only segment one object at a time.")
            return None
        mask = masks[0]
        if mask is None:
            seg = prompt_based_segmentation.segment_from_box_and_points(
                predictor, boxes[0], points, labels, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            seg = prompt_based_segmentation.segment_from_mask(
                predictor, mask, box=boxes[0], points=points, labels=labels, image_embeddings=image_embeddings, i=i
            ).squeeze()

    # Only point prompts were given.
    elif have_points and not have_boxes:
        seg = prompt_based_segmentation.segment_from_points(
            predictor, points, labels, image_embeddings=image_embeddings, i=i
        ).squeeze()

    # Only box prompts were given.
    elif not have_points and have_boxes:
        seg = np.zeros(shape, dtype="uint32")

        if len(boxes) > 1 and not multiple_box_prompts:
            print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
            print("You can only segment one object at a time in 3d.")
            return None

        # Batch this?
        for seg_id, (box, mask) in enumerate(zip(boxes, masks), 1):
            if mask is None:
                prediction = prompt_based_segmentation.segment_from_box(
                    predictor, box, image_embeddings=image_embeddings, i=i
                ).squeeze()
            else:
                prediction = prompt_based_segmentation.segment_from_mask(
                    predictor, mask, box=box, image_embeddings=image_embeddings, i=i,
                    box_extension=box_extension,
                ).squeeze()
            seg[prediction] = seg_id

    return seg


def _compute_movement(seg, t0, t1):
    """Return the shift of the center of mass of object 1 from frame t1 to frame t0 (as center0 - center1).

    If the object is empty in either frame, zero movement is returned instead of NaN, which would
    otherwise propagate into the motion model.
    """

    def compute_center(t):
        # computation with center of mass
        coords = np.where(seg[t] == 1)
        if len(coords[0]) == 0:
            return None
        return np.array([np.mean(coords[0]), np.mean(coords[1])])

    center0 = compute_center(t0)
    center1 = compute_center(t1)
    if center0 is None or center1 is None:
        return np.zeros(2, dtype="float64")

    move = center0 - center1
    return move.astype("float64")


def _shift_object(mask, motion_model):
    """Shift a mask by the motion vector with nearest-neighbor interpolation."""
    mask_shifted = np.zeros_like(mask)
    shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
    return mask_shifted


def track_from_prompts(
    point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
    stop_upper, threshold, projection, motion_smoothing=0.5, box_extension=0, update_progress=None,
):
    """@private

    Track an object forward in time, starting from the first annotated frame.

    Frames with prompts use the segmentation already stored in seg. Other frames are segmented
    from the (motion-shifted) mask of the previous frame. Tracking stops at a division, at a stop
    annotation in the last annotated frame, when the IOU drops below threshold (only after the
    last annotated frame), or at the end of the data.

    Returns:
        The updated segmentation and whether tracking stopped because of a division.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    # shift the segmentation based on the motion model and update the motion model
    def _update_motion_model(seg, t, t0, motion_model):
        if t in (t0, t0 + 1):  # this is the first or second frame, we don't have a motion yet
            pass
        elif t == t0 + 2:  # this the third frame, we initialize the motion model
            current_move = _compute_movement(seg, t - 1, t - 2)
            motion_model = current_move
        else:  # we already have a motion model and update it
            current_move = _compute_movement(seg, t - 1, t - 2)
            alpha = motion_smoothing
            motion_model = alpha * motion_model + (1 - alpha) * current_move

        return motion_model

    has_division = False
    motion_model = None
    verbose = False

    t0 = int(slices.min())
    t = t0 + 1
    while True:

        # update the motion model
        motion_model = _update_motion_model(seg, t, t0, motion_model)

        # use the segmentation from prompts if we are in a slice with prompts
        if t in slices:
            seg_prev = None
            seg_t = seg[t]
            # currently using the box layer doesn't work for keeping track of the track state
            # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
            track_state = prompt_layer_to_state(point_prompts, t)

        # otherwise project the mask (under the motion model) and segment the next slice from the mask
        else:
            if verbose:
                print(f"Tracking object in frame {t} with movement {motion_model}")

            seg_prev = seg[t - 1]
            # shift the segmentation according to the motion model
            if motion_model is not None:
                seg_prev = _shift_object(seg_prev, motion_model)

            seg_t = prompt_based_segmentation.segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=t,
                use_mask=use_mask, use_box=use_box, use_points=use_points,
                box_extension=box_extension, use_single_point=use_single_point,
            )
            track_state = "track"

            # are we beyond the last slice with prompt?
            # if no: we continue tracking because we know we need to connect to a future frame
            # if yes: we only continue tracking if overlaps are above the threshold
            if t < slices[-1]:
                seg_prev = None

            update_progress(1)

        if (threshold is not None) and (seg_prev is not None):
            iou = util.compute_iou(seg_prev, seg_t)
            if iou < threshold:
                msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
                print(msg)
                break

        # stop if we have a division
        if track_state == "division":
            has_division = True
            break

        seg[t] = seg_t
        t += 1

        # stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
        if t == slices[-1] and stop_upper:
            break

        # stop if we are at the last slice
        if t == seg.shape[0]:
            break

    return seg, has_division


def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, device, tile_shape, halo):
    """Set the embedding widget's controls from the given command line / function arguments."""

    # Update the index for model family, eg. 'Natural Images (SAM)', 'Light Microscopy', etc.
    supported_dropdown_maps = {
        "lm": "Light Microscopy",
        "em_organelles": "Electron Microscopy",
        "medical_imaging": "Medical Imaging",
        "histopathology": "Histopathology",
    }

    model_family = "Natural Images (SAM)"  # If no suffix patterns match, stick to 'Natural Images (SAM)' family.
    for k, v in supported_dropdown_maps.items():
        if model_type.endswith(k):
            model_family = v
            break

    # findText returns -1 if the entry is not found; 0 is a valid index.
    index = widget.model_family_dropdown.findText(model_family)
    if index >= 0:
        widget.model_family_dropdown.setCurrentIndex(index)

    # Update the index for model size, eg. 'base', 'tiny', etc.
    # Assumes model names of the form 'vit_<size>...' (e.g. 'vit_b_lm'), so the size letter is at index 4.
    # Unknown or short names leave the dropdown unchanged instead of raising.
    size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
    model_size = size_map.get(model_type[4]) if len(model_type) > 4 else None
    if model_size is not None:
        index = widget.model_size_dropdown.findText(model_size)
        if index >= 0:
            widget.model_size_dropdown.setCurrentIndex(index)

    if save_path is not None and isinstance(save_path, str):
        widget.embeddings_save_path_param.setText(str(save_path))

    if checkpoint_path is not None:
        widget.custom_weights_param.setText(str(checkpoint_path))

    if device is not None:
        widget.device = device
        index = widget.device_dropdown.findText(device)
        # Do not call setCurrentIndex(-1): that would clear the dropdown selection.
        if index >= 0:
            widget.device_dropdown.setCurrentIndex(index)

    if tile_shape is not None:
        widget.tile_x_param.setValue(tile_shape[0])
        widget.tile_y_param.setValue(tile_shape[1])

    if halo is not None:
        widget.halo_x_param.setValue(halo[0])
        widget.halo_y_param.setValue(halo[1])


# TODO: Read parameters from checkpoint path if it is given instead.
def _sync_autosegment_widget(widget, model_type, checkpoint_path, update_decoder=None):
    """Set the automatic segmentation widget's parameters from the per-model default settings."""
    if update_decoder is not None:
        widget._reset_segmentation_mode(update_decoder)

    if widget.with_decoder:
        settings = model_settings.AIS_SETTINGS.get(model_type, {})
        params = ("center_distance_thresh", "boundary_distance_thresh")
    else:
        settings = model_settings.AMG_SETTINGS.get(model_type, {})
        params = ("pred_iou_thresh", "stability_score_thresh", "min_object_size")

    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])


# TODO: Read parameters from checkpoint path if it is given instead.
def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
    """Set the nd-segmentation widget's parameters from the per-model default settings."""
    settings = model_settings.ND_SEGMENT_SETTINGS.get(model_type, {})

    if "projection_mode" in settings:
        projection_mode = settings["projection_mode"]
        widget.projection = projection_mode
        # findText returns -1 if the entry is not found; 0 is a valid index.
        index = widget.projection_dropdown.findText(projection_mode)
        if index >= 0:
            widget.projection_dropdown.setCurrentIndex(index)

    params = ("iou_threshold", "box_extension")
    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])


def _load_amg_state(embedding_path):
    """Load the cached automatic mask generation states, keyed by the index in the file name.

    Returns {"cache_folder": None} if there is no embedding path on disk.
    """
    if embedding_path is None or not os.path.exists(embedding_path):
        return {"cache_folder": None}

    cache_folder = os.path.join(embedding_path, "amg_state")
    os.makedirs(cache_folder, exist_ok=True)
    amg_state = {"cache_folder": cache_folder}

    state_paths = glob(os.path.join(cache_folder, "*.pkl"))
    for path in state_paths:
        # NOTE(security): pickle.load can execute arbitrary code; only load embedding folders you trust.
        with open(path, "rb") as f:
            state = pickle.load(f)
        i = int(Path(path).stem.split("-")[-1])
        amg_state[i] = state
    return amg_state


def _load_is_state(embedding_path):
    """Load the cached instance segmentation (decoder) states from 'is_state.h5', keyed by index.

    Returns {"cache_path": None} if there is no embedding path on disk. Opening in "a" mode creates
    the cache file if it does not exist yet.
    """
    if embedding_path is None or not os.path.exists(embedding_path):
        return {"cache_path": None}

    cache_path = os.path.join(embedding_path, "is_state.h5")
    is_state = {"cache_path": cache_path}

    with h5py.File(cache_path, "a") as f:
        for name, g in f.items():
            i = int(name.split("-")[-1])
            state = {
                "foreground": g["foreground"][:],
                "boundary_distances": g["boundary_distances"][:],
                "center_distances": g["center_distances"][:],
            }
            is_state[i] = state

    return is_state
def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
        None is returned instead when a single negative point (stop annotation) is selected
        and with_stop_annotation is True.
    """
    coords = layer.data
    raw_labels = layer.properties["label"]
    assert len(coords) == len(raw_labels)

    if i is not None:
        # 3d / timeseries: keep the points in slice i (rounded) and drop the slice coordinate.
        assert coords.shape[1] == 3, f"{coords.shape}"
        in_slice = np.round(coords[:, 0]) == i
        selected_points = coords[in_slice][:, 1:]
        selected_labels = raw_labels[in_slice]
    else:
        assert coords.shape[1] == 2, f"{coords.shape}"
        selected_points, selected_labels = coords, raw_labels
    assert len(selected_points) == len(selected_labels)

    if track_id is not None:
        assert i is not None
        all_track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])
        same_track = all_track_ids[in_slice] == track_id
        selected_labels = selected_labels[same_track]
        selected_points = selected_points[same_track]
    assert len(selected_points) == len(selected_labels)

    binary_labels = np.array([int(lbl == "positive") for lbl in selected_labels])

    # A lone negative point signals 'stop'; this is reported as None.
    is_stop = with_stop_annotation and len(selected_points) == 1 and binary_labels[0] == 0
    if is_stop:
        return None
    return selected_points, binary_labels
Extract point prompts for SAM from a napari point layer.
Arguments:
- layer: The point layer from which to extract the prompts.
- i: Index for the data (required for 3d or timeseries data).
- track_id: Id of the current track (required for tracking data).
- with_stop_annotation: Whether a single negative point will be interpreted as stop annotation or just returned as normal prompt.
Returns:
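The point coordinates for the prompts and the labels (positive or negative / 1 or 0) for the prompts. None is returned instead if a single negative point (stop annotation) was selected and with_stop_annotation is True.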
def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
    for 'ellipse' and 'polygon' shapes. Other shape types are ignored with a warning.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts, each as [row_min, col_min, row_max, col_max].
        The mask prompts (None for 'rectangle' shapes).
    """

    def _to_prompts(shape_data, shape_types):
        boxes, masks = [], []

        for data, type_ in zip(shape_data, shape_types):

            if type_ == "rectangle":
                boxes.append(data)
                masks.append(None)

            elif type_ == "ellipse":
                boxes.append(data)
                center = np.mean(data, axis=0)
                # NOTE(review): relies on the vertex order napari uses for an ellipse's bounding box.
                radius_r = ((data[2] - data[1]) / 2)[0]
                radius_c = ((data[1] - data[0]) / 2)[1]
                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            elif type_ == "polygon":
                boxes.append(data)
                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            else:
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")

        # map to correct box format
        boxes = [
            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
        ]
        return boxes, masks

    shape_data, shape_types = layer.data, layer.shape_type
    assert len(shape_data) == len(shape_types)
    if len(shape_data) == 0:
        return [], []

    if i is not None:

        def _in_slice(data):
            # The slice coordinate is rounded, consistent with point_layer_to_prompts.
            return (np.round(data[:, 0]) == i).all()

        if track_id is None:
            prompt_selection = [j for j, data in enumerate(shape_data) if _in_slice(data)]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            prompt_selection = [
                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
                if _in_slice(data) and this_track_id == track_id
            ]

        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
        shape_types = [shape_types[j] for j in prompt_selection]

    boxes, masks = _to_prompts(shape_data, shape_types)
    return boxes, masks
Extract prompts for SAM from a napari shape layer.
Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask for 'ellipse' and 'polygon' shapes. Shapes of any other type are ignored with a warning.
Arguments:
- layer: The napari shape layer.
- shape: The image shape.
- i: Index for the data (required for 3d or timeseries data).
- track_id: Id of the current track (required for tracking data).
Returns:
The box prompts, each in the format [min_row, min_col, max_row, max_col]. The mask prompts, one per box (None for rectangles).
def
prompt_layer_to_state(prompt_layer: napari.layers.points.points.Points, i: int) -> str:
def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Determine the tracking state of frame `i` from a point layer.

    Only used by annotator_tracking.

    Args:
        prompt_layer: The napari point layer. Its points must have 3 coordinates (frame first)
            and it must carry a 'state' property.
        i: The frame to inspect.

    Returns:
        "division" if any point in frame `i` has the state "division", otherwise "track".
    """
    all_states = prompt_layer.properties["state"]
    all_points = prompt_layer.data
    assert all_points.shape[1] == 3, f"{all_points.shape}"

    in_frame = all_points[:, 0] == i
    frame_points = all_points[in_frame][:, 1:]
    frame_states = all_states[in_frame]
    assert len(frame_points) == len(frame_states)

    # a single point labelled 'division' marks the whole frame as a division
    is_division = any(label == "division" for label in frame_states)
    return "division" if is_division else "track"
Get the state of the track from a point layer for a given timeframe.
Only relevant for annotator_tracking.
Arguments:
- prompt_layer: The napari point layer.
- i: Timeframe of the data.
Returns:
The state of this frame (either "division" or "track").
def
prompt_layers_to_state( point_layer: napari.layers.points.points.Points, box_layer: napari.layers.shapes.shapes.Shapes, i: int) -> str:
def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer. Its points must have 3 coordinates (frame first)
            and it must carry a 'state' property.
        box_layer: The napari box layer. It must carry a 'state' property.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    points = point_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    mask = points[:, 0] == i
    # only index the property array if at least one point lies in this frame
    this_state = point_states[mask].tolist() if mask.any() else []

    box_states = box_layer.properties["state"]
    # a box belongs to frame i only if all of its vertices lie in that frame
    this_state.extend(
        box_state for box, box_state in zip(box_layer.data, box_states)
        if (box[:, 0] == i).all()
    )

    # we set the state to 'division' if at least one point or box in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    return "track"
Get the state of the track from a point layer and shape layer for a given timeframe.
Only relevant for annotator_tracking.
Arguments:
- point_layer: The napari point layer.
- box_layer: The napari box layer.
- i: Timeframe of the data.
Returns:
The state of this frame (either "division" or "track").