micro_sam.sam_annotator.util
import os
import pickle
import warnings
import argparse
from glob import glob
from pathlib import Path
from typing import List, Optional, Tuple

import h5py
import napari
import numpy as np
from skimage import draw
from scipy.ndimage import shift

from .. import prompt_based_segmentation, util
from .. import _model_settings as model_settings
from ..multi_dimensional_segmentation import _validate_projection

# Green and Red
LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
"""@private"""


#
# Misc helper functions
#


def toggle_label(prompts):
    """@private

    Flip the label used for newly added prompts between "positive" and "negative".

    Args:
        prompts: The napari layer whose `current_properties["label"]` is toggled.
    """
    # get the currently selected label
    current_properties = prompts.current_properties
    current_label = current_properties["label"][0]
    new_label = "negative" if current_label == "positive" else "positive"
    current_properties["label"] = np.array([new_label])
    # Re-assign the (modified) properties so that the layer picks up the change.
    prompts.current_properties = current_properties
    prompts.refresh()
    prompts.refresh_colors()


def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):
    """Create the command line parser shared by the annotator entry points.

    Args:
        description: The description of the command line tool.
        with_segmentation_result: Whether to add the arguments for loading an existing segmentation.
        with_instance_segmentation: Whether to add the arguments for automatic instance segmentation.

    Returns:
        The argument parser.
    """
    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "-i", "--input", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )

    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "It is recommended to pass this argument and store the embeddings if you want to open the annotator "
        "multiple times for this image. Otherwise the embeddings will be recomputed every time."
    )

    if with_segmentation_result:
        parser.add_argument(
            "-s", "--segmentation_result",
            help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
            "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
            "have saved intermediate results from the annotator and want to continue with your annotations. "
            "Supports the same file formats as 'input'."
        )
        parser.add_argument(
            "-sk", "--segmentation_key",
            help="The key for opening the segmentation data. Same rules as for 'key' apply."
        )

    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )

    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )

    if with_instance_segmentation:
        parser.add_argument(
            "--precompute_amg_state", action="store_true",
            help="Whether to precompute the state for automatic instance segmentation. "
            "This will lead to a longer start-up time, but the automatic instance segmentation can "
            "be run directly once the tool has started."
        )
        # NOTE: this is a 'store_false' flag: `prefer_decoder` is True by default and passing
        # the flag sets it to False, so the help text describes the effect of passing it.
        parser.add_argument(
            "--prefer_decoder", action="store_false",
            help="Pass this flag to NOT use decoder based instance segmentation. By default decoder based "
            "instance segmentation is used if the model being used has an additional decoder for that purpose."
        )

    return parser


def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
    """@private

    Remove all point and shape prompts and, optionally, reset the 'current_object' layer.

    Args:
        viewer: The napari viewer.
        clear_segmentations: Whether to also reset the 'current_object' layer to zeros.
    """
    viewer.layers["point_prompts"].data = []
    viewer.layers["point_prompts"].refresh()
    if "prompts" in viewer.layers:
        # Select all prompts and then remove them.
        # This is how it worked before napari 0.5.
        # viewer.layers["prompts"].data = []
        viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
        viewer.layers["prompts"].remove_selected()
        viewer.layers["prompts"].refresh()
    if not clear_segmentations:
        return
    viewer.layers["current_object"].data = np.zeros(viewer.layers["current_object"].data.shape, dtype="uint32")
    viewer.layers["current_object"].refresh()


def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None:
    """@private

    Remove the prompts in slice / frame `i` and, optionally, clear that slice in 'current_object'.

    Args:
        viewer: The napari viewer.
        i: The slice / frame index.
        clear_segmentations: Whether to also set slice `i` of the 'current_object' layer to zero.
    """
    # The prompts are removed via selection + remove_selected (as in clear_annotations) rather than
    # by assigning a filtered `data` array, so that per-item properties (e.g. 'label', 'track_id')
    # and shape types stay attached to the items that remain.
    # NOTE(review): overwriting `data` with a shorter array is not guaranteed to keep these aligned.
    point_layer = viewer.layers["point_prompts"]
    point_indices = np.where(point_layer.data[:, 0] == i)[0]
    if len(point_indices) > 0:
        point_layer.selected_data = set(point_indices.tolist())
        point_layer.remove_selected()
    point_layer.refresh()

    if "prompts" in viewer.layers:
        shape_layer = viewer.layers["prompts"]
        # A shape belongs to slice i if all of its vertices lie in that slice.
        shape_indices = {j for j, prompt in enumerate(shape_layer.data) if (prompt[:, 0] == i).all()}
        if shape_indices:
            shape_layer.selected_data = shape_indices
            shape_layer.remove_selected()
        shape_layer.refresh()

    if not clear_segmentations:
        return
    viewer.layers["current_object"].data[i] = 0
    viewer.layers["current_object"].refresh()


#
# Helper functions to extract prompts from napari layers.
#


def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
            If given, only the points in slice / frame `i` are used and their
            leading (slice / frame) coordinate is dropped.
        track_id: Id of the current track (required for tracking data).
            Can only be used together with `i`.
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
        If `with_stop_annotation` is set and the selection is a single negative point,
        None is returned instead (stop annotation).
    """
    coords = layer.data
    point_labels = layer.properties["label"]
    assert len(coords) == len(point_labels)

    if i is None:
        # 2d data: all points are prompts.
        assert coords.shape[1] == 2, f"{coords.shape}"
        selected_coords, selected_labels = coords, point_labels
    else:
        # nd data: keep only the points in slice / frame i, without the slice coordinate.
        assert coords.shape[1] == 3, f"{coords.shape}"
        in_slice = coords[:, 0] == i
        selected_coords = coords[in_slice][:, 1:]
        selected_labels = point_labels[in_slice]
        assert len(selected_coords) == len(selected_labels)

    if track_id is not None:
        assert i is not None
        # The track ids are stored as layer properties; convert them to int before comparing.
        slice_track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])[in_slice]
        in_track = slice_track_ids == track_id
        selected_labels, selected_coords = selected_labels[in_track], selected_coords[in_track]
        assert len(selected_coords) == len(selected_labels)

    # Map the string labels to binary point labels.
    binary_labels = np.array([1 if label == "positive" else 0 for label in selected_labels])

    # A single negative point is the 'stop' signal, which is reported by returning None.
    if with_stop_annotation and len(selected_coords) == 1 and binary_labels[0] == 0:
        return None
    return selected_coords, binary_labels


def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
    for 'ellipse' and 'polygon' shapes. Other shape types are ignored with a warning.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
            If given, only the shapes that lie entirely in slice / frame `i` are used.
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts, each as `[min_0, min_1, max_0, max_1]` over the two spatial axes.
        The mask prompts (None for rectangles), aligned with the box prompts.
    """

    def _to_prompts(shape_data, shape_types):
        # Convert 2d shape vertices into bounding boxes and (optional) masks.
        boxes, masks = [], []

        for data, type_ in zip(shape_data, shape_types):

            if type_ == "rectangle":
                boxes.append(data)
                masks.append(None)

            elif type_ == "ellipse":
                boxes.append(data)
                # NOTE(review): assumes the ellipse vertices are the corners of its bounding box
                # in napari's vertex order.
                center = np.mean(data, axis=0)
                radius_r = ((data[2] - data[1]) / 2)[0]
                radius_c = ((data[1] - data[0]) / 2)[1]
                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            elif type_ == "polygon":
                boxes.append(data)
                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            else:
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")

        # map to correct box format
        boxes = [
            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
        ]
        return boxes, masks

    shape_data, shape_types = layer.data, layer.shape_type
    assert len(shape_data) == len(shape_types)
    if len(shape_data) == 0:
        return [], []

    if i is not None:
        # Select the shapes whose vertices all lie in slice i (and that belong to the track, if given).
        if track_id is None:
            prompt_selection = [j for j, data in enumerate(shape_data) if (data[:, 0] == i).all()]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            prompt_selection = [
                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
                if ((data[:, 0] == i).all() and this_track_id == track_id)
            ]

        # Drop the slice coordinate of the selected shapes.
        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
        shape_types = [shape_types[j] for j in prompt_selection]

    boxes, masks = _to_prompts(shape_data, shape_types)
    return boxes, masks


def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = prompt_layer.properties["state"]

    coords = prompt_layer.data
    assert coords.shape[1] == 3, f"{coords.shape}"
    in_frame = coords[:, 0] == i
    frame_points = coords[in_frame][:, 1:]
    frame_states = point_states[in_frame]
    assert len(frame_points) == len(frame_states)

    # A single 'division' point in this frame marks the whole frame as a division.
    is_division = any(point_state == "division" for point_state in frame_states)
    return "division" if is_division else "track"


def prompt_layers_to_state(
    point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int
) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer.
        box_layer: The napari box layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    coords = point_layer.data
    assert coords.shape[1] == 3, f"{coords.shape}"
    in_frame = coords[:, 0] == i
    frame_states = point_states[in_frame].tolist() if in_frame.any() else []

    # Add the states of the boxes whose vertices all lie in frame i.
    frame_states += [
        box_state for box, box_state in zip(box_layer.data, box_layer.properties["state"])
        if (box[:, 0] == i).all()
    ]

    # A single 'division' prompt in this frame marks the whole frame as a division.
    is_division = any(frame_state == "division" for frame_state in frame_states)
    return "division" if is_division else "track"


#
# Helper functions to run (multi-dimensional) segmentation on napari layers.
#


def segment_slices_with_prompts(
    predictor, point_prompts, box_prompts, image_embeddings, shape, track_id=None, update_progress=None,
):
    """@private

    Segment all slices / frames of a 3d volume or timeseries that carry point or box prompts.

    Args:
        predictor: The SAM predictor.
        point_prompts: The napari point layer with the point prompts.
        box_prompts: The napari shape layer with the box / shape prompts.
        image_embeddings: The precomputed image embeddings.
        shape: The 3d shape of the data.
        track_id: Id of the current track (only for tracking data).
        update_progress: Optional callback that is called with 1 once for each annotated slice.

    Returns:
        The segmentation, with the prompt-based result written into the annotated slices.
        The annotated slices (stop annotations in the middle of the range are removed).
        Whether the lowest annotated slice is a stop annotation.
        Whether the highest annotated slice is a stop annotation.
    """
    assert len(shape) == 3
    image_shape = shape[1:]
    seg = np.zeros(shape, dtype="uint32")

    z_values = point_prompts.data[:, 0]
    z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\
        np.zeros(0, dtype="int")

    if track_id is not None:
        track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
        assert len(track_ids_points) == len(z_values)
        z_values = z_values[track_ids_points == track_id]

        if len(z_values_boxes) > 0:
            track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
            assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
            z_values_boxes = z_values_boxes[track_ids_boxes == track_id]

    slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
    stop_lower, stop_upper = False, False

    if update_progress is None:
        def update_progress(*args):
            pass

    # NOTE: `slices` may be re-assigned inside the loop; the loop keeps iterating the original array.
    for i in slices:
        points_i = point_layer_to_prompts(point_prompts, i, track_id)

        # do we end the segmentation at the outer slices?
        if points_i is None:

            if i == slices[0]:  # The bottom slice is a stop slice.
                stop_lower = True
                seg[i] = 0
            elif i == slices[-1]:  # The top slice is a stop slice.
                stop_upper = True
                seg[i] = 0
            else:  # We have a stop annotation somewhere in the middle. Ignore this.
                # Remove this slice from the annotated slices, so that it is segmented via
                # projection in the next step.
                slices = np.setdiff1d(slices, i)
                print(f"You have provided a stop annotation (single red point) in slice {i},")
                print("but you have annotated slices above or below it. This stop annotation will")
                print(f"be ignored and the slice {i} will be segmented normally.")

            update_progress(1)
            continue

        boxes, masks = shape_layer_to_prompts(box_prompts, image_shape, i=i, track_id=track_id)
        points, labels = points_i

        seg_i = prompt_segmentation(
            predictor, points, labels, boxes, masks, image_shape, multiple_box_prompts=False,
            image_embeddings=image_embeddings, i=i
        )
        if seg_i is None:
            print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
            print("This will lead to a wrong segmentation across slices or frames.")
            print(f"Please correct the prompts in {i} and rerun the segmentation.")
            # Still report progress, so that every annotated slice advances the progress by one.
            update_progress(1)
            continue

        seg[i] = seg_i
        update_progress(1)

    return seg, slices, stop_lower, stop_upper


# For advanced batching: match prompts to already segmented objects and continue segmentation.
def _match_prompts(previous_segmentation, points, boxes, seg_ids):
    # Placeholder: the matching is not implemented yet and an empty mapping is returned.
    # Create a mapping between ids and prompts.
    batched_prompts = {}
    # seg_boundaries = find_boundaries(previous_segmentation, mode="inner")
    # indices = distance_transform_edt(seg_boundaries, return_distance=False, return_index=True)
    return batched_prompts


def _batched_interactive_segmentation(predictor, points, labels, boxes, image_embeddings, i, previous_segmentation):
    """Segment one object per positive point and per box prompt.

    Each positive point and each box is treated as a separate object, and all negative points
    are added to each of these prompts. The objects get the ids 1, 2, ... in a new label image
    with the shape of the previous segmentation (of slice `i`, if given). Later objects
    overwrite earlier ones where they overlap.
    """
    prev_seg = previous_segmentation if i is None else previous_segmentation[i]
    seg = np.zeros(prev_seg.shape, dtype="uint32")

    # seg_ids = np.unique(previous_segmentation)
    # assert seg_ids[0] == 0

    batched_points, batched_labels = [], []
    negative_points, negative_labels = [], []
    for j in range(len(points)):
        if labels[j] == 1:  # positive point
            batched_points.append(points[j:j+1])
            batched_labels.append(labels[j:j+1])
        else:  # negative points
            negative_points.append(points[j:j+1])
            negative_labels.append(labels[j:j+1])

    # Merge the negative points into single arrays, so that they can be passed on their own
    # (together with a box prompt) or appended to a positive point.
    if negative_points:
        negative_points = np.concatenate(negative_points)
        negative_labels = np.concatenate(negative_labels)
    else:
        negative_points = negative_labels = None

    batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
    batched_prompts.extend([(box, None, None) for box in boxes])
    batched_prompts = {seg_id: prompt for seg_id, prompt in enumerate(batched_prompts, 1)}

    # For advanced batching: match prompts to already segmented objects and continue segmentation.
    # (This is left here as a reference for how this can be implemented.
    # I have not decided yet if this is actually a good idea or not.)
    # # If we have no objects: this is the first call for a batched segmentation.
    # # We treat each positive point or box as a separate object.
    # if len(seg_ids) == 1:
    #     # Create a list of all prompts.
    #     batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
    #     batched_prompts.extend([(box, None, None) for box in boxes])
    #     batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}

    # # Otherwise we match the prompts to existing objects.
    # else:
    #     batched_prompts = _match_prompts(prev_seg, batched_points, boxes, seg_ids)

    for seg_id, prompt in batched_prompts.items():
        box, point, label = prompt
        if negative_points is not None:
            if point is None:
                point, label = negative_points, negative_labels
            else:
                point = np.concatenate([point, negative_points])
                label = np.concatenate([label, negative_labels])

        if (box is not None) and (point is not None):
            prediction = prompt_based_segmentation.segment_from_box_and_points(
                predictor, box, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()
        elif (box is not None) and (point is None):
            prediction = prompt_based_segmentation.segment_from_box(
                predictor, box, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            prediction = prompt_based_segmentation.segment_from_points(
                predictor, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()

        seg[prediction] = seg_id

    return seg


def prompt_segmentation(
    predictor, points, labels, boxes, masks, shape, multiple_box_prompts,
    image_embeddings=None, i=None, box_extension=0, batched=None,
    previous_segmentation=None,
):
    """@private

    Segment an image (or slice `i`) from the given point, box and mask prompts.

    Returns:
        The segmentation, or None if no prompts were given or the prompt combination
        is not supported.
    """
    assert len(points) == len(labels)
    have_points = len(points) > 0
    have_boxes = len(boxes) > 0

    # No prompts were given, return None.
    if not have_points and not have_boxes:
        return

    # Batched interactive segmentation.
    elif batched:
        assert previous_segmentation is not None
        seg = _batched_interactive_segmentation(
            predictor, points, labels, boxes, image_embeddings, i, previous_segmentation
        )

    # Box and point prompts were given.
    elif have_points and have_boxes:
        if len(boxes) > 1:
            print("You have provided point prompts and more than one box prompt.")
            print("This setting is currently not supported.")
            print("When providing both points and boxes you can only segment one object at a time.")
            return
        mask = masks[0]
        if mask is None:
            seg = prompt_based_segmentation.segment_from_box_and_points(
                predictor, boxes[0], points, labels, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            seg = prompt_based_segmentation.segment_from_mask(
                predictor, mask, box=boxes[0], points=points, labels=labels, image_embeddings=image_embeddings, i=i
            ).squeeze()

    # Only point prompts were given.
    elif have_points and not have_boxes:
        seg = prompt_based_segmentation.segment_from_points(
            predictor, points, labels, image_embeddings=image_embeddings, i=i
        ).squeeze()

    # Only box prompts were given.
    elif not have_points and have_boxes:
        seg = np.zeros(shape, dtype="uint32")

        if len(boxes) > 1 and not multiple_box_prompts:
            print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
            print("You can only segment one object at a time in 3d.")
            return

        # Batch this?
        for seg_id, (box, mask) in enumerate(zip(boxes, masks), 1):
            if mask is None:
                prediction = prompt_based_segmentation.segment_from_box(
                    predictor, box, image_embeddings=image_embeddings, i=i
                ).squeeze()
            else:
                prediction = prompt_based_segmentation.segment_from_mask(
                    predictor, mask, box=box, image_embeddings=image_embeddings, i=i,
                    box_extension=box_extension,
                ).squeeze()
            seg[prediction] = seg_id

    return seg


def _compute_movement(seg, t0, t1):
    """Compute the movement of the object (label 1) from frame `t0` to frame `t1`.

    Returns:
        A float64 vector with one entry per spatial axis: `center(t1) - center(t0)`, where
        `center` is the mean position of the object pixels. If the object is missing in
        either frame, a zero vector (no movement) is returned.
    """

    def compute_center(t):
        # computation with center of mass; None if the object is absent in frame t
        coords = np.where(seg[t] == 1)
        if coords[0].size == 0:
            return None
        # One mean per axis. (np.array(a, b) would interpret b as the dtype.)
        return np.array([np.mean(axis_coords) for axis_coords in coords])

    center0 = compute_center(t0)
    center1 = compute_center(t1)
    if center0 is None or center1 is None:
        return np.zeros(seg.ndim - 1, dtype="float64")

    move = center1 - center0
    return move.astype("float64")


def _shift_object(mask, motion_model):
    """Shift `mask` by the vector `motion_model`, using nearest neighbor interpolation (order 0)."""
    mask_shifted = np.zeros_like(mask)
    shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
    return mask_shifted


def track_from_prompts(
    point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
    stop_upper, threshold, projection,
    motion_smoothing=0.5, box_extension=0, update_progress=None,
):
    """@private

    Track a single object through the frames of `seg`, starting after the first annotated frame.

    Frames listed in `slices` keep their prompt-based segmentation. All other frames are segmented
    from the (motion-shifted) mask of the previous frame. `seg` is updated in place and returned.

    Returns:
        The updated segmentation.
        Whether tracking stopped because of a division.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    # shift the segmentation based on the motion model and update the motion model
    def _update_motion_model(seg, t, t0, motion_model):
        if t in (t0, t0 + 1):  # this is the first or second frame, we don't have a motion yet
            pass
        elif t == t0 + 2:  # this the third frame, we initialize the motion model
            # Movement from frame t - 2 to frame t - 1, used to extrapolate the position in frame t.
            current_move = _compute_movement(seg, t - 2, t - 1)
            motion_model = current_move
        else:  # we already have a motion model and update it
            current_move = _compute_movement(seg, t - 2, t - 1)
            alpha = motion_smoothing
            motion_model = alpha * motion_model + (1 - alpha) * current_move

        return motion_model

    has_division = False
    motion_model = None
    verbose = False

    t0 = int(slices.min())
    t = t0 + 1
    # stop if we are past the last slice (also guards the case where t0 is already the last frame)
    while t < seg.shape[0]:

        # update the motion model
        motion_model = _update_motion_model(seg, t, t0, motion_model)

        # use the segmentation from prompts if we are in a slice with prompts
        if t in slices:
            seg_prev = None
            seg_t = seg[t]
            # currently using the box layer doesn't work for keeping track of the track state
            # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
            track_state = prompt_layer_to_state(point_prompts, t)

        # otherwise project the mask (under the motion model) and segment the next slice from the mask
        else:
            if verbose:
                print(f"Tracking object in frame {t} with movement {motion_model}")

            seg_prev = seg[t - 1]
            # shift the segmentation according to the motion model
            if motion_model is not None:
                seg_prev = _shift_object(seg_prev, motion_model)

            seg_t = prompt_based_segmentation.segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=t,
                use_mask=use_mask, use_box=use_box, use_points=use_points,
                box_extension=box_extension, use_single_point=use_single_point,
            )
            track_state = "track"

            # are we beyond the last slice with prompt?
            # if no: we continue tracking because we know we need to connect to a future frame
            # if yes: we only continue tracking if overlaps are above the threshold
            if t < slices[-1]:
                seg_prev = None

            update_progress(1)

        if (threshold is not None) and (seg_prev is not None):
            iou = util.compute_iou(seg_prev, seg_t)
            if iou < threshold:
                msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
                print(msg)
                break

        # stop if we have a division
        if track_state == "division":
            has_division = True
            break

        seg[t] = seg_t
        t += 1

        # stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
        if t == slices[-1] and stop_upper:
            break

    return seg, has_division


def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, device, tile_shape, halo):
    """Apply the given settings to the embedding widget.

    `model_type` is always applied; all other settings only if they are not None.
    """
    widget.model_type = model_type
    # findText returns -1 if the entry is not found (QComboBox semantics); 0 is a valid index.
    index = widget.model_dropdown.findText(model_type)
    if index >= 0:
        widget.model_dropdown.setCurrentIndex(index)

    if save_path is not None:
        widget.embeddings_save_path_param.setText(str(save_path))

    if checkpoint_path is not None:
        widget.custom_weights_param.setText(str(checkpoint_path))

    if device is not None:
        widget.device = device
        index = widget.device_dropdown.findText(device)
        # Only change the dropdown for a known device; index -1 would clear the selection.
        if index >= 0:
            widget.device_dropdown.setCurrentIndex(index)

    if tile_shape is not None:
        widget.tile_x_param.setValue(tile_shape[0])
        widget.tile_y_param.setValue(tile_shape[1])

    if halo is not None:
        widget.halo_x_param.setValue(halo[0])
        widget.halo_y_param.setValue(halo[1])


# Read parameters from checkpoint path if it is given instead.
def _sync_autosegment_widget(widget, model_type, checkpoint_path, update_decoder=None):
    """Apply the model-specific default parameters to the automatic segmentation widget.

    Args:
        widget: The automatic segmentation widget.
        model_type: The model name used to look up the default settings.
        checkpoint_path: Currently unused.
        update_decoder: If not None, passed to `widget._reset_segmentation_mode` first.
    """
    if update_decoder is not None:
        widget._reset_segmentation_mode(update_decoder)

    # Decoder-based segmentation and the automatic mask generator have different parameters.
    if widget.with_decoder:
        settings = model_settings.AIS_SETTINGS.get(model_type, {})
        params = ("center_distance_thresh", "boundary_distance_thresh")
    else:
        settings = model_settings.AMG_SETTINGS.get(model_type, {})
        params = ("pred_iou_thresh", "stability_score_thresh", "min_object_size")

    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])


# Read parameters from checkpoint path if it is given instead.
def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
    """Apply the model-specific default parameters to the multi-dimensional segmentation widget.

    Args:
        widget: The multi-dimensional segmentation widget.
        model_type: The model name used to look up the default settings.
        checkpoint_path: Currently unused.
    """
    settings = model_settings.ND_SEGMENT_SETTINGS.get(model_type, {})

    if "projection_mode" in settings:
        projection_mode = settings["projection_mode"]
        widget.projection = projection_mode
        # findText returns -1 if the entry is not found (QComboBox semantics); 0 is a valid index.
        index = widget.projection_dropdown.findText(projection_mode)
        if index >= 0:
            widget.projection_dropdown.setCurrentIndex(index)

    params = ("iou_threshold", "box_extension")
    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])


def _load_amg_state(embedding_path):
    """Load the cached automatic mask generation states from `<embedding_path>/amg_state`.

    Creates the cache folder if it does not exist yet.

    Returns:
        A dict with the key "cache_folder" (None if `embedding_path` is None or does not exist)
        and one entry per cached state, keyed by the integer after the last '-' in the file name.
    """
    if embedding_path is None or not os.path.exists(embedding_path):
        return {"cache_folder": None}

    cache_folder = os.path.join(embedding_path, "amg_state")
    os.makedirs(cache_folder, exist_ok=True)
    amg_state = {"cache_folder": cache_folder}

    state_paths = glob(os.path.join(cache_folder, "*.pkl"))
    for path in state_paths:
        # NOTE(review): pickle.load can execute arbitrary code; only load caches from trusted locations.
        with open(path, "rb") as f:
            state = pickle.load(f)
        i = int(Path(path).stem.split("-")[-1])
        amg_state[i] = state
    return amg_state


def _load_is_state(embedding_path):
    """Load the cached instance segmentation states from `<embedding_path>/is_state.h5`.

    The file is opened in append mode, so it is created if it does not exist yet.

    Returns:
        A dict with the key "cache_path" (None if `embedding_path` is None or does not exist)
        and one entry per cached state, keyed by the integer after the last '-' in the group name.
    """
    if embedding_path is None or not os.path.exists(embedding_path):
        return {"cache_path": None}

    cache_path = os.path.join(embedding_path, "is_state.h5")
    is_state = {"cache_path": cache_path}

    with h5py.File(cache_path, "a") as f:
        for name, g in f.items():
            i = int(name.split("-")[-1])
            state = {
                "foreground": g["foreground"][:],
                "boundary_distances": g["boundary_distances"][:],
                "center_distances": g["center_distances"][:],
            }
            is_state[i] = state

    return is_state
def
point_layer_to_prompts( layer: napari.layers.points.points.Points, i=None, track_id=None, with_stop_annotation=True) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
            If given, only the points in slice / frame `i` are used and their
            leading (slice / frame) coordinate is dropped.
        track_id: Id of the current track (required for tracking data).
            Can only be used together with `i`.
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
        If `with_stop_annotation` is set and the selection is a single negative point,
        None is returned instead (stop annotation).
    """
    coords = layer.data
    point_labels = layer.properties["label"]
    assert len(coords) == len(point_labels)

    if i is None:
        # 2d data: all points are prompts.
        assert coords.shape[1] == 2, f"{coords.shape}"
        selected_coords, selected_labels = coords, point_labels
    else:
        # nd data: keep only the points in slice / frame i, without the slice coordinate.
        assert coords.shape[1] == 3, f"{coords.shape}"
        in_slice = coords[:, 0] == i
        selected_coords = coords[in_slice][:, 1:]
        selected_labels = point_labels[in_slice]
        assert len(selected_coords) == len(selected_labels)

    if track_id is not None:
        assert i is not None
        # The track ids are stored as layer properties; convert them to int before comparing.
        slice_track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])[in_slice]
        in_track = slice_track_ids == track_id
        selected_labels, selected_coords = selected_labels[in_track], selected_coords[in_track]
        assert len(selected_coords) == len(selected_labels)

    # Map the string labels to binary point labels.
    binary_labels = np.array([1 if label == "positive" else 0 for label in selected_labels])

    # A single negative point is the 'stop' signal, which is reported by returning None.
    if with_stop_annotation and len(selected_coords) == 1 and binary_labels[0] == 0:
        return None
    return selected_coords, binary_labels
Extract point prompts for SAM from a napari point layer.
Arguments:
- layer: The point layer from which to extract the prompts.
- i: Index for the data (required for 3d or timeseries data).
- track_id: Id of the current track (required for tracking data).
- with_stop_annotation: Whether a single negative point will be interpreted as stop annotation or just returned as normal prompt.
Returns:
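The point coordinates for the prompts. The labels (positive or negative / 1 or 0) for the prompts. If `with_stop_annotation` is set and the selection consists of a single negative point, `None` is returned instead (stop annotation).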
def
shape_layer_to_prompts( layer: napari.layers.shapes.shapes.Shapes, shape: Tuple[int, int], i=None, track_id=None) -> Tuple[List[numpy.ndarray], List[Optional[numpy.ndarray]]]:
def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
    for 'ellipse' and 'polygon' shapes. Other shape types are ignored with a warning.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
            If given, only the shapes that lie entirely in slice / frame `i` are used.
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts, each as `[min_0, min_1, max_0, max_1]` over the two spatial axes.
        The mask prompts (None for rectangles), aligned with the box prompts.
    """

    def _to_prompts(shape_data, shape_types):
        # Convert 2d shape vertices into bounding boxes and (optional) masks.
        boxes, masks = [], []

        for data, type_ in zip(shape_data, shape_types):

            if type_ == "rectangle":
                boxes.append(data)
                masks.append(None)

            elif type_ == "ellipse":
                boxes.append(data)
                # NOTE(review): assumes the ellipse vertices are the corners of its bounding box
                # in napari's vertex order.
                center = np.mean(data, axis=0)
                radius_r = ((data[2] - data[1]) / 2)[0]
                radius_c = ((data[1] - data[0]) / 2)[1]
                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            elif type_ == "polygon":
                boxes.append(data)
                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            else:
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")

        # map to correct box format
        boxes = [
            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
        ]
        return boxes, masks

    shape_data, shape_types = layer.data, layer.shape_type
    assert len(shape_data) == len(shape_types)
    if len(shape_data) == 0:
        return [], []

    if i is not None:
        # Select the shapes whose vertices all lie in slice i (and that belong to the track, if given).
        if track_id is None:
            prompt_selection = [j for j, data in enumerate(shape_data) if (data[:, 0] == i).all()]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            prompt_selection = [
                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
                if ((data[:, 0] == i).all() and this_track_id == track_id)
            ]

        # Drop the slice coordinate of the selected shapes.
        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
        shape_types = [shape_types[j] for j in prompt_selection]

    boxes, masks = _to_prompts(shape_data, shape_types)
    return boxes, masks
Extract prompts for SAM from a napari shape layer.
Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask for 'ellipse' and 'polygon' shapes.
Arguments:
- layer: The napari shape layer.
- shape: The image shape.
- i: Index for the data (required for 3d or timeseries data).
- track_id: Id of the current track (required for tracking data).
Returns:
The box prompts. The mask prompts.
def
prompt_layer_to_state(prompt_layer: napari.layers.points.points.Points, i: int) -> str:
def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = prompt_layer.properties["state"]

    coords = prompt_layer.data
    assert coords.shape[1] == 3, f"{coords.shape}"
    in_frame = coords[:, 0] == i
    frame_points = coords[in_frame][:, 1:]
    frame_states = point_states[in_frame]
    assert len(frame_points) == len(frame_states)

    # A single 'division' point in this frame marks the whole frame as a division.
    is_division = any(point_state == "division" for point_state in frame_states)
    return "division" if is_division else "track"
Get the state of the track from a point layer for a given timeframe.
Only relevant for annotator_tracking.
Arguments:
- prompt_layer: The napari layer.
- i: Timeframe of the data.
Returns:
The state of this frame (either "division" or "track").
def
prompt_layers_to_state( point_layer: napari.layers.points.points.Points, box_layer: napari.layers.shapes.shapes.Shapes, i: int) -> str:
def prompt_layers_to_state(
    point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int
) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer.
        box_layer: The napari box layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    coords = point_layer.data
    assert coords.shape[1] == 3, f"{coords.shape}"
    in_frame = coords[:, 0] == i
    frame_states = point_states[in_frame].tolist() if in_frame.any() else []

    # Add the states of the boxes whose vertices all lie in frame i.
    frame_states += [
        box_state for box, box_state in zip(box_layer.data, box_layer.properties["state"])
        if (box[:, 0] == i).all()
    ]

    # A single 'division' prompt in this frame marks the whole frame as a division.
    is_division = any(frame_state == "division" for frame_state in frame_states)
    return "division" if is_division else "track"
Get the state of the track from a point layer and shape layer for a given timeframe.
Only relevant for annotator_tracking.
Arguments:
- point_layer: The napari point layer.
- box_layer: The napari box layer.
- i: Timeframe of the data.
Returns:
The state of this frame (either "division" or "track").