micro_sam.sam_annotator.util
import os
import pickle
import warnings
import argparse
from glob import glob
from pathlib import Path
from typing import List, Optional, Tuple

import h5py
import napari
import numpy as np
from skimage import draw
from scipy.ndimage import shift

from .. import prompt_based_segmentation, util
from .. import _model_settings as model_settings
from ..multi_dimensional_segmentation import _validate_projection

# Green and Red
LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
"""@private"""


#
# Misc helper functions
#


def toggle_label(prompts):
    """@private

    Switch the label used for the next point prompt between 'positive' and 'negative'.
    """
    # get the currently selected label
    current_properties = prompts.current_properties
    current_label = current_properties["label"][0]
    new_label = "negative" if current_label == "positive" else "positive"
    current_properties["label"] = np.array([new_label])
    prompts.current_properties = current_properties
    prompts.refresh()
    prompts.refresh_colors()


def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):
    """Create the argument parser shared by the annotator command line interfaces.

    Args:
        description: The description of the command line tool.
        with_segmentation_result: Whether to add the arguments for loading an existing segmentation.
        with_instance_segmentation: Whether to add the arguments for automatic instance segmentation.

    Returns:
        The argument parser.
    """
    available_models = ", ".join(list(util.get_model_names()))

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "-i", "--input", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-e", "--embedding_path",
        help="The filepath for saving/loading the pre-computed image embeddings. "
        "It is recommended to pass this argument and store the embeddings if you want to open the annotator "
        "multiple times for this image. Otherwise the embeddings will be recomputed every time."
    )

    if with_segmentation_result:
        parser.add_argument(
            "-s", "--segmentation_result",
            help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
            "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
            "have saved intermediate results from the annotator and want to continue with your annotations. "
            "Supports the same file formats as 'input'."
        )
        parser.add_argument(
            "-sk", "--segmentation_key",
            help="The key for opening the segmentation data. Same rules as for 'key' apply."
        )

    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )

    if with_instance_segmentation:
        parser.add_argument(
            "--precompute_amg_state", action="store_true",
            help="Whether to precompute the state for automatic instance segmentation. "
            "This will lead to a longer start-up time, but the automatic instance segmentation can "
            "be run directly once the tool has started."
        )
        # NOTE: this flag uses 'store_false': 'prefer_decoder' is True by default and
        # passing the flag sets it to False. The help text describes this explicitly.
        parser.add_argument(
            "--prefer_decoder", action="store_false",
            help="Pass this flag to disable decoder based instance segmentation. By default it is used if the model "
            "being used has an additional decoder for that purpose."
        )

    return parser


def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
    """@private

    Remove all point and shape prompts and (optionally) reset the 'current_object' layer.
    """
    viewer.layers["point_prompts"].data = []
    viewer.layers["point_prompts"].refresh()
    if "prompts" in viewer.layers:
        # Select all prompts and then remove them.
        # This is how it worked before napari 0.5.
        # viewer.layers["prompts"].data = []
        viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
        viewer.layers["prompts"].remove_selected()
        viewer.layers["prompts"].refresh()
    if not clear_segmentations:
        return
    viewer.layers["current_object"].data = np.zeros(viewer.layers["current_object"].data.shape, dtype="uint32")
    viewer.layers["current_object"].refresh()


def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None:
    """@private

    Remove the point and shape prompts of slice / frame `i` and (optionally) zero the
    'current_object' layer in that slice.
    """
    # Remove via selection instead of re-assigning a filtered `data`: when `data` shrinks,
    # napari trims the per-element properties (e.g. 'label', 'track_id') from the END,
    # which would misalign the properties of the remaining prompts.
    point_layer = viewer.layers["point_prompts"]
    # Round the slice coordinate, consistent with `point_layer_to_prompts`.
    on_slice = np.where(np.round(point_layer.data[:, 0]) == i)[0]
    if len(on_slice) > 0:
        point_layer.selected_data = set(on_slice.tolist())
        point_layer.remove_selected()
    point_layer.refresh()

    if "prompts" in viewer.layers:
        # Same for shapes, where re-assigning `data` would additionally trim the shape types.
        prompt_layer = viewer.layers["prompts"]
        on_slice = [j for j, prompt in enumerate(prompt_layer.data) if (prompt[:, 0] == i).all()]
        if on_slice:
            prompt_layer.selected_data = set(on_slice)
            prompt_layer.remove_selected()
        prompt_layer.refresh()

    if not clear_segmentations:
        return
    viewer.layers["current_object"].data[i] = 0
    viewer.layers["current_object"].refresh()


#
# Helper functions to extract prompts from napari layers.
#


def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
    """

    points = layer.data
    labels = layer.properties["label"]
    assert len(points) == len(labels)

    if i is None:
        assert points.shape[1] == 2, f"{points.shape}"
        this_points, this_labels = points, labels
    else:
        assert points.shape[1] == 3, f"{points.shape}"
        mask = np.round(points[:, 0]) == i
        this_points = points[mask][:, 1:]
        this_labels = labels[mask]
    assert len(this_points) == len(this_labels)

    if track_id is not None:
        assert i is not None
        track_ids = np.array(list(map(int, layer.properties["track_id"])))[mask]
        track_id_mask = track_ids == track_id
        this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
    assert len(this_points) == len(this_labels)

    this_labels = np.array([1 if label == "positive" else 0 for label in this_labels])
    # a single point with a negative label is interpreted as 'stop' signal
    # in this case we return None
    if with_stop_annotation and (len(this_points) == 1 and this_labels[0] == 0):
        return None

    return this_points, this_labels


def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
    for 'ellipse' and 'polygon' shapes.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts.
        The mask prompts.
    """

    def _to_prompts(shape_data, shape_types):
        boxes, masks = [], []

        for data, type_ in zip(shape_data, shape_types):

            if type_ == "rectangle":
                boxes.append(data)
                masks.append(None)

            elif type_ == "ellipse":
                boxes.append(data)
                center = np.mean(data, axis=0)
                radius_r = ((data[2] - data[1]) / 2)[0]
                radius_c = ((data[1] - data[0]) / 2)[1]
                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            elif type_ == "polygon":
                boxes.append(data)
                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            else:
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")

        # map to correct box format
        boxes = [
            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
        ]
        return boxes, masks

    shape_data, shape_types = layer.data, layer.shape_type
    assert len(shape_data) == len(shape_types)
    if len(shape_data) == 0:
        return [], []

    if i is not None:
        if track_id is None:
            prompt_selection = [j for j, data in enumerate(shape_data) if (data[:, 0] == i).all()]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            prompt_selection = [
                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
                if ((data[:, 0] == i).all() and this_track_id == track_id)
            ]

        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
        shape_types = [shape_types[j] for j in prompt_selection]

    boxes, masks = _to_prompts(shape_data, shape_types)
    return boxes, masks


def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    state = prompt_layer.properties["state"]

    points = prompt_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Round the frame coordinate, consistent with `point_layer_to_prompts`.
    mask = np.round(points[:, 0]) == i
    this_points = points[mask][:, 1:]
    this_state = state[mask]
    assert len(this_points) == len(this_state)

    # we set the state to 'division' if at least one point in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    else:
        return "track"


def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer.
        box_layer: The napari box layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    points = point_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Round the frame coordinate of the points, consistent with `point_layer_to_prompts`.
    mask = np.round(points[:, 0]) == i
    if mask.sum() > 0:
        this_state = point_states[mask].tolist()
    else:
        this_state = []

    box_states = box_layer.properties["state"]
    this_box_states = [
        box_state for box, box_state in zip(box_layer.data, box_states)
        if (box[:, 0] == i).all()
    ]
    this_state.extend(this_box_states)

    # we set the state to 'division' if at least one point or box in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    else:
        return "track"


#
# Helper functions to run (multi-dimensional) segmentation on napari layers.
#


def segment_slices_with_prompts(
    predictor, point_prompts, box_prompts, image_embeddings, shape, track_id=None, update_progress=None,
):
    """@private

    Segment every slice (or frame) of a 3d volume that has point or box prompts.

    Returns the segmentation, the (integer) indices of the segmented slices, and whether
    the lowest / highest annotated slice carries a stop annotation (single negative point).
    """
    assert len(shape) == 3
    image_shape = shape[1:]
    seg = np.zeros(shape, dtype="uint32")

    z_values = np.round(point_prompts.data[:, 0])
    if box_prompts.data:
        z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data])
    else:
        z_values_boxes = np.zeros(0, dtype="int")

    if track_id is not None:
        track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
        assert len(track_ids_points) == len(z_values)
        z_values = z_values[track_ids_points == track_id]

        if len(z_values_boxes) > 0:
            track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
            assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
            z_values_boxes = z_values_boxes[track_ids_boxes == track_id]

    slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
    stop_lower, stop_upper = False, False

    if update_progress is None:
        def update_progress(*args):
            pass

    for i in slices:
        points_i = point_layer_to_prompts(point_prompts, i, track_id)

        # do we end the segmentation at the outer slices?
        if points_i is None:

            if i == slices[0]:  # The bottom slice is a stop slice.
                stop_lower = True
                seg[i] = 0
            elif i == slices[-1]:  # The top slice is a stop slice.
                stop_upper = True
                seg[i] = 0
            else:  # We have a stop annotation somewhere in the middle. Ignore this.
                # Remove this slice from the annotated slices, so that it is segmented via
                # projection in the next step.
                slices = np.setdiff1d(slices, i)
                print(f"You have provided a stop annotation (single red point) in slice {i},")
                print("but you have annotated slices above or below it. This stop annotation will")
                print(f"be ignored and the slice {i} will be segmented normally.")

            update_progress(1)
            continue

        boxes, masks = shape_layer_to_prompts(box_prompts, image_shape, i=i, track_id=track_id)
        points, labels = points_i

        seg_i = prompt_segmentation(
            predictor, points, labels, boxes, masks, image_shape, multiple_box_prompts=False,
            image_embeddings=image_embeddings, i=i
        )
        if seg_i is None:
            print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
            print("This will lead to a wrong segmentation across slices or frames.")
            print(f"Please correct the prompts in {i} and rerun the segmentation.")
            # Still advance the progress, so that it reaches its total of one step per slice.
            update_progress(1)
            continue

        seg[i] = seg_i
        update_progress(1)

    return seg, slices, stop_lower, stop_upper


# For advanced batching: match prompts to already segmented objects and continue segmentation.
def _match_prompts(previous_segmentation, points, boxes, seg_ids):
    # Placeholder: currently always returns an empty mapping (see the commented-out
    # reference code in `_batched_interactive_segmentation`).
    # Create a mapping between ids and prompts.
    batched_prompts = {}
    # seg_boundaries = find_boundaries(previous_segmentation, mode="inner")
    # indices = distance_transform_edt(seg_boundaries, return_distance=False, return_index=True)
    return batched_prompts


def _batched_interactive_segmentation(predictor, points, labels, boxes, image_embeddings, i, previous_segmentation):
    """Segment one object per positive point and per box; negative points are shared by all objects."""
    prev_seg = previous_segmentation if i is None else previous_segmentation[i]
    seg = np.zeros(prev_seg.shape, dtype="uint32")

    # seg_ids = np.unique(previous_segmentation)
    # assert seg_ids[0] == 0

    batched_points, batched_labels = [], []
    negative_points, negative_labels = [], []
    for j in range(len(points)):
        if labels[j] == 1:  # positive point
            batched_points.append(points[j:j+1])
            batched_labels.append(labels[j:j+1])
        else:  # negative points
            negative_points.append(points[j:j+1])
            negative_labels.append(labels[j:j+1])

    batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
    batched_prompts.extend([(box, None, None) for box in boxes])
    batched_prompts = {idx: prompt for idx, prompt in enumerate(batched_prompts, 1)}

    # For advanced batching: match prompts to already segmented objects and continue segmentation.
    # (This is left here as a reference for how this can be implemented.
    # I have not decided yet if this is actually a good idea or not.)
    # # If we have no objects: this is the first call for a batched segmentation.
    # # We treat each positive point or box as a separate obejct.
    # if len(seg_ids) == 1:
    #     # Create a list of all prompts.
    #     batched_prompts = [(None, point, label) for point, label in zip(batched_points, batched_labels)]
    #     batched_prompts.extend([(box, None, None) for box in boxes])
    #     batched_prompts = {i: prompt for i, prompt in enumerate(batched_prompts, 1)}

    # # Otherwise we match the prompts to existing objects.
    # else:
    #     batched_prompts = _match_prompts(prev_seg, batched_points, boxes, seg_ids)

    for seg_id, prompt in batched_prompts.items():
        box, point, label = prompt
        if len(negative_points) > 0:
            if point is None:
                # Box without a positive point: use only the negative points, concatenated
                # into one (N, 2) array (the lists hold one (1, 2) array per point).
                point = np.concatenate(negative_points)
                label = np.concatenate(negative_labels)
            else:
                point = np.concatenate([point] + negative_points)
                label = np.concatenate([label] + negative_labels)

        if (box is not None) and (point is not None):
            prediction = prompt_based_segmentation.segment_from_box_and_points(
                predictor, box, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()
        elif (box is not None) and (point is None):
            prediction = prompt_based_segmentation.segment_from_box(
                predictor, box, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            prediction = prompt_based_segmentation.segment_from_points(
                predictor, point, label, image_embeddings=image_embeddings, i=i
            ).squeeze()

        seg[prediction] = seg_id

    return seg


def prompt_segmentation(
    predictor, points, labels, boxes, masks, shape, multiple_box_prompts,
    image_embeddings=None, i=None, box_extension=0, batched=None, previous_segmentation=None,
):
    """@private

    Segment from the given point / box / mask prompts. Returns None if there are no prompts
    or the prompt combination is not supported.
    """
    assert len(points) == len(labels)
    have_points = len(points) > 0
    have_boxes = len(boxes) > 0

    # No prompts were given, return None.
    if not have_points and not have_boxes:
        return

    # Batched interactive segmentation.
    elif batched:
        assert previous_segmentation is not None
        seg = _batched_interactive_segmentation(
            predictor, points, labels, boxes, image_embeddings, i, previous_segmentation
        )

    # Box and point prompts were given.
    elif have_points and have_boxes:
        if len(boxes) > 1:
            print("You have provided point prompts and more than one box prompt.")
            print("This setting is currently not supported.")
            print("When providing both points and boxes you can only segment one object at a time.")
            return
        mask = masks[0]
        if mask is None:
            seg = prompt_based_segmentation.segment_from_box_and_points(
                predictor, boxes[0], points, labels, image_embeddings=image_embeddings, i=i
            ).squeeze()
        else:
            seg = prompt_based_segmentation.segment_from_mask(
                predictor, mask, box=boxes[0], points=points, labels=labels, image_embeddings=image_embeddings, i=i
            ).squeeze()

    # Only point prompts were given.
    elif have_points and not have_boxes:
        seg = prompt_based_segmentation.segment_from_points(
            predictor, points, labels, image_embeddings=image_embeddings, i=i
        ).squeeze()

    # Only box prompts were given.
    elif not have_points and have_boxes:
        seg = np.zeros(shape, dtype="uint32")

        if len(boxes) > 1 and not multiple_box_prompts:
            print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
            print("You can only segment one object at a time in 3d.")
            return

        # Batch this?
        for seg_id, (box, mask) in enumerate(zip(boxes, masks), 1):
            if mask is None:
                prediction = prompt_based_segmentation.segment_from_box(
                    predictor, box, image_embeddings=image_embeddings, i=i
                ).squeeze()
            else:
                prediction = prompt_based_segmentation.segment_from_mask(
                    predictor, mask, box=box, image_embeddings=image_embeddings, i=i,
                    box_extension=box_extension,
                ).squeeze()
            seg[prediction] = seg_id

    return seg


def _compute_movement(seg, t0, t1):
    """Compute the displacement of the object's center of mass from frame `t0` to frame `t1`.

    Returns a float64 array with one value per spatial axis of `seg[t0]`.
    If the object is empty in either frame a zero displacement is returned.
    """

    def compute_center(t):
        # computation with center of mass; one mean per spatial axis.
        # (The previous `np.array(mean_0, mean_1)` passed the second mean as `dtype`
        # and silently returned only the first coordinate.)
        coords = np.where(seg[t] != 0)
        if len(coords[0]) == 0:
            return None
        return np.array([np.mean(c) for c in coords])

    center0 = compute_center(t0)
    center1 = compute_center(t1)
    if center0 is None or center1 is None:
        return np.zeros(seg[t0].ndim, dtype="float64")

    move = center1 - center0
    return move.astype("float64")


def _shift_object(mask, motion_model):
    """Shift `mask` by `motion_model` with nearest-neighbor interpolation."""
    mask_shifted = np.zeros_like(mask)
    shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
    return mask_shifted


def track_from_prompts(
    point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
    stop_upper, threshold, projection, motion_smoothing=0.5, box_extension=0, update_progress=None,
):
    """@private

    Propagate the segmentation in `seg` forward from the first annotated frame.

    Returns the updated segmentation and whether tracking stopped at a division.
    """
    use_box, use_mask, use_points, use_single_point = _validate_projection(projection)

    if update_progress is None:
        def update_progress(*args):
            pass

    # shift the segmentation based on the motion model and update the motion model
    def _update_motion_model(seg, t, t0, motion_model):
        if t in (t0, t0 + 1):  # this is the first or second frame, we don't have a motion yet
            pass
        elif t == t0 + 2:  # this the third frame, we initialize the motion model
            # Movement from frame t - 2 to frame t - 1, i.e. in the forward (tracking) direction.
            current_move = _compute_movement(seg, t - 2, t - 1)
            motion_model = current_move
        else:  # we already have a motion model and update it
            current_move = _compute_movement(seg, t - 2, t - 1)
            alpha = motion_smoothing
            motion_model = alpha * motion_model + (1 - alpha) * current_move

        return motion_model

    has_division = False
    motion_model = None
    verbose = False

    t0 = int(slices.min())
    t = t0 + 1
    while True:

        # update the motion model
        motion_model = _update_motion_model(seg, t, t0, motion_model)

        # use the segmentation from prompts if we are in a slice with prompts
        if t in slices:
            seg_prev = None
            seg_t = seg[t]
            # currently using the box layer doesn't work for keeping track of the track state
            # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
            track_state = prompt_layer_to_state(point_prompts, t)

        # otherwise project the mask (under the motion model) and segment the next slice from the mask
        else:
            if verbose:
                print(f"Tracking object in frame {t} with movement {motion_model}")

            seg_prev = seg[t - 1]
            # shift the segmentation according to the motion model
            if motion_model is not None:
                seg_prev = _shift_object(seg_prev, motion_model)

            seg_t = prompt_based_segmentation.segment_from_mask(
                predictor, seg_prev, image_embeddings=image_embeddings, i=t,
                use_mask=use_mask, use_box=use_box, use_points=use_points,
                box_extension=box_extension, use_single_point=use_single_point,
            )
            track_state = "track"

            # are we beyond the last slice with prompt?
            # if no: we continue tracking because we know we need to connect to a future frame
            # if yes: we only continue tracking if overlaps are above the threshold
            if t < slices[-1]:
                seg_prev = None

            update_progress(1)

        if (threshold is not None) and (seg_prev is not None):
            iou = util.compute_iou(seg_prev, seg_t)
            if iou < threshold:
                msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
                print(msg)
                break

        # stop if we have a division
        if track_state == "division":
            has_division = True
            break

        seg[t] = seg_t
        t += 1

        # stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
        if t == slices[-1] and stop_upper:
            break

        # stop if we are at the last slice
        if t == seg.shape[0]:
            break

    return seg, has_division


def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, device, tile_shape, halo):
    """Copy the command line settings into the embedding widget."""
    widget.model_type = model_type
    index = widget.model_dropdown.findText(model_type)
    # findText returns -1 for a missing entry; 0 is a valid index (the first entry).
    if index >= 0:
        widget.model_dropdown.setCurrentIndex(index)

    if save_path is not None:
        widget.embeddings_save_path_param.setText(str(save_path))

    if checkpoint_path is not None:
        widget.custom_weights_param.setText(str(checkpoint_path))

    if device is not None:
        widget.device = device
        index = widget.device_dropdown.findText(device)
        # Do not call setCurrentIndex(-1) for an unknown device; that would clear the selection.
        if index >= 0:
            widget.device_dropdown.setCurrentIndex(index)

    if tile_shape is not None:
        widget.tile_x_param.setValue(tile_shape[0])
        widget.tile_y_param.setValue(tile_shape[1])

    if halo is not None:
        widget.halo_x_param.setValue(halo[0])
        widget.halo_y_param.setValue(halo[1])


# Read parameters from checkpoint path if it is given instead.
def _sync_autosegment_widget(widget, model_type, checkpoint_path, update_decoder=None):
    """Set the automatic segmentation widget to the default parameters for `model_type`.

    `checkpoint_path` is currently not used (see the note above this function).
    """
    if update_decoder is not None:
        widget._reset_segmentation_mode(update_decoder)

    # Decoder based segmentation and AMG have different settings and parameters.
    if widget.with_decoder:
        settings = model_settings.AIS_SETTINGS.get(model_type, {})
        params = ("center_distance_thresh", "boundary_distance_thresh")
    else:
        settings = model_settings.AMG_SETTINGS.get(model_type, {})
        params = ("pred_iou_thresh", "stability_score_thresh", "min_object_size")

    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])


# Read parameters from checkpoint path if it is given instead.
def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
    """Set the nd-segmentation widget to the default parameters for `model_type`.

    `checkpoint_path` is currently not used.
    """
    settings = model_settings.ND_SEGMENT_SETTINGS.get(model_type, {})

    if "projection_mode" in settings:
        projection_mode = settings["projection_mode"]
        widget.projection = projection_mode
        index = widget.projection_dropdown.findText(projection_mode)
        # findText returns -1 for a missing entry; 0 is a valid index (the first entry).
        if index >= 0:
            widget.projection_dropdown.setCurrentIndex(index)

    params = ("iou_threshold", "box_extension")
    for param in params:
        if param in settings:
            getattr(widget, f"{param}_param").setValue(settings[param])


def _load_amg_state(embedding_path):
    """Load the cached automatic mask generation states from `<embedding_path>/amg_state`.

    Returns a dict with the key "cache_folder" (None if there is no embedding path) and
    one entry per cached '*-<index>.pkl' file, keyed by that integer index.
    """
    if embedding_path is None or not os.path.exists(embedding_path):
        return {"cache_folder": None}

    cache_folder = os.path.join(embedding_path, "amg_state")
    os.makedirs(cache_folder, exist_ok=True)
    amg_state = {"cache_folder": cache_folder}

    state_paths = glob(os.path.join(cache_folder, "*.pkl"))
    for path in state_paths:
        # SECURITY: pickle.load can execute arbitrary code. Only open embedding folders
        # that were created by this tool / a trusted source.
        with open(path, "rb") as f:
            state = pickle.load(f)
        i = int(Path(path).stem.split("-")[-1])
        amg_state[i] = state
    return amg_state


def _load_is_state(embedding_path):
    """Load the cached instance segmentation states from `<embedding_path>/is_state.h5`.

    Returns a dict with the key "cache_path" (None if there is no embedding path) and
    one entry per group in the file, keyed by the integer suffix of the group name.
    Note that the file is opened in append mode, so it is created if it does not exist.
    """
    if embedding_path is None or not os.path.exists(embedding_path):
        return {"cache_path": None}

    cache_path = os.path.join(embedding_path, "is_state.h5")
    is_state = {"cache_path": cache_path}

    with h5py.File(cache_path, "a") as f:
        for name, g in f.items():
            i = int(name.split("-")[-1])
            state = {
                "foreground": g["foreground"][:],
                "boundary_distances": g["boundary_distances"][:],
                "center_distances": g["center_distances"][:],
            }
            is_state[i] = state

    return is_state
def
point_layer_to_prompts( layer: napari.layers.points.points.Points, i=None, track_id=None, with_stop_annotation=True) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
def point_layer_to_prompts(
    layer: napari.layers.Points, i=None, track_id=None, with_stop_annotation=True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Extract point prompts for SAM from a napari point layer.

    Args:
        layer: The point layer from which to extract the prompts.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).
        with_stop_annotation: Whether a single negative point will be interpreted
            as stop annotation or just returned as normal prompt.

    Returns:
        The point coordinates for the prompts.
        The labels (positive or negative / 1 or 0) for the prompts.
    """
    coords = layer.data
    point_labels = layer.properties["label"]
    assert len(coords) == len(point_labels)

    if i is not None:
        # 3d / timeseries: keep the points in slice i (rounded) and drop the slice axis.
        assert coords.shape[1] == 3, f"{coords.shape}"
        in_slice = np.round(coords[:, 0]) == i
        coords, point_labels = coords[in_slice][:, 1:], point_labels[in_slice]
    else:
        assert coords.shape[1] == 2, f"{coords.shape}"
    assert len(coords) == len(point_labels)

    if track_id is not None:
        assert i is not None
        # Restrict further to the points that belong to the requested track.
        all_track_ids = np.array([int(tid) for tid in layer.properties["track_id"]])
        in_track = all_track_ids[in_slice] == track_id
        point_labels, coords = point_labels[in_track], coords[in_track]
    assert len(coords) == len(point_labels)

    # Encode labels as 1 (positive) / 0 (negative).
    binary_labels = np.array([int(lbl == "positive") for lbl in point_labels])

    # A lone negative point signals 'stop' (when enabled).
    is_stop = len(coords) == 1 and binary_labels[0] == 0
    if with_stop_annotation and is_stop:
        return None
    return coords, binary_labels
Extract point prompts for SAM from a napari point layer.
Arguments:
- layer: The point layer from which to extract the prompts.
- i: Index for the data (required for 3d or timeseries data).
- track_id: Id of the current track (required for tracking data).
- with_stop_annotation: Whether a single negative point will be interpreted as stop annotation or just returned as normal prompt.
Returns:
The point coordinates for the prompts. The labels (positive or negative / 1 or 0) for the prompts.
def
shape_layer_to_prompts( layer: napari.layers.shapes.shapes.Shapes, shape: Tuple[int, int], i=None, track_id=None) -> Tuple[List[numpy.ndarray], List[Optional[numpy.ndarray]]]:
def shape_layer_to_prompts(
    layer: napari.layers.Shapes, shape: Tuple[int, int], i=None, track_id=None
) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
    """Extract prompts for SAM from a napari shape layer.

    Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask
    for 'ellipse' and 'polygon' shapes.

    Args:
        layer: The napari shape layer.
        shape: The image shape.
        i: Index for the data (required for 3d or timeseries data).
        track_id: Id of the current track (required for tracking data).

    Returns:
        The box prompts.
        The mask prompts.
    """

    def _to_prompts(shape_data, shape_types):
        # Convert each supported shape into a box and an optional mask (None for rectangles).
        boxes, masks = [], []

        for data, type_ in zip(shape_data, shape_types):

            if type_ == "rectangle":
                boxes.append(data)
                masks.append(None)

            elif type_ == "ellipse":
                boxes.append(data)
                center = np.mean(data, axis=0)
                radius_r = ((data[2] - data[1]) / 2)[0]
                radius_c = ((data[1] - data[0]) / 2)[1]
                rr, cc = draw.ellipse(center[0], center[1], radius_r, radius_c, shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            elif type_ == "polygon":
                boxes.append(data)
                rr, cc = draw.polygon(data[:, 0], data[:, 1], shape=shape)
                mask = np.zeros(shape, dtype=bool)
                mask[rr, cc] = 1
                masks.append(mask)

            else:
                warnings.warn(f"Shape type {type_} is not supported and will be ignored.")

        # map to correct box format: [min_row, min_col, max_row, max_col]
        boxes = [
            np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
        ]
        return boxes, masks

    shape_data, shape_types = layer.data, layer.shape_type
    assert len(shape_data) == len(shape_types)
    if len(shape_data) == 0:
        return [], []

    if i is not None:
        # Keep only the shapes lying entirely in slice / frame i (and of the given track).
        if track_id is None:
            prompt_selection = [j for j, data in enumerate(shape_data) if (data[:, 0] == i).all()]
        else:
            track_ids = np.array(list(map(int, layer.properties["track_id"])))
            prompt_selection = [
                j for j, (data, this_track_id) in enumerate(zip(shape_data, track_ids))
                if ((data[:, 0] == i).all() and this_track_id == track_id)
            ]

        shape_data = [shape_data[j][:, 1:] for j in prompt_selection]
        shape_types = [shape_types[j] for j in prompt_selection]

    boxes, masks = _to_prompts(shape_data, shape_types)
    return boxes, masks
Extract prompts for SAM from a napari shape layer.
Extracts the bounding box for 'rectangle' shapes and the bounding box and corresponding mask for 'ellipse' and 'polygon' shapes.
Arguments:
- layer: The napari shape layer.
- shape: The image shape.
- i: Index for the data (required for 3d or timeseries data).
- track_id: Id of the current track (required for tracking data).
Returns:
The box prompts. The mask prompts.
def
prompt_layer_to_state(prompt_layer: napari.layers.points.points.Points, i: int) -> str:
def prompt_layer_to_state(prompt_layer: napari.layers.Points, i: int) -> str:
    """Get the state of the track from a point layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        prompt_layer: The napari layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    state = prompt_layer.properties["state"]

    points = prompt_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Round the frame coordinate, consistent with `point_layer_to_prompts`, so that the
    # state is read from exactly the points whose prompts are used for this frame.
    mask = np.round(points[:, 0]) == i
    this_points = points[mask][:, 1:]
    this_state = state[mask]
    assert len(this_points) == len(this_state)

    # we set the state to 'division' if at least one point in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    else:
        return "track"
Get the state of the track from a point layer for a given timeframe.
Only relevant for annotator_tracking.
Arguments:
- prompt_layer: The napari layer.
- i: Timeframe of the data.
Returns:
The state of this frame (either "division" or "track").
def
prompt_layers_to_state( point_layer: napari.layers.points.points.Points, box_layer: napari.layers.shapes.shapes.Shapes, i: int) -> str:
def prompt_layers_to_state(point_layer: napari.layers.Points, box_layer: napari.layers.Shapes, i: int) -> str:
    """Get the state of the track from a point layer and shape layer for a given timeframe.

    Only relevant for annotator_tracking.

    Args:
        point_layer: The napari point layer.
        box_layer: The napari box layer.
        i: Timeframe of the data.

    Returns:
        The state of this frame (either "division" or "track").
    """
    point_states = point_layer.properties["state"]

    points = point_layer.data
    assert points.shape[1] == 3, f"{points.shape}"
    # Round the frame coordinate of the points, consistent with `point_layer_to_prompts`.
    mask = np.round(points[:, 0]) == i
    if mask.sum() > 0:
        this_state = point_states[mask].tolist()
    else:
        this_state = []

    # Boxes belong to frame i if all their vertices lie in it (as in `shape_layer_to_prompts`).
    box_states = box_layer.properties["state"]
    this_box_states = [
        box_state for box, box_state in zip(box_layer.data, box_states)
        if (box[:, 0] == i).all()
    ]
    this_state.extend(this_box_states)

    # we set the state to 'division' if at least one point or box in this frame has a division label
    if any(st == "division" for st in this_state):
        return "division"
    else:
        return "track"
Get the state of the track from a point layer and shape layer for a given timeframe.
Only relevant for annotator_tracking.
Arguments:
- point_layer: The napari point layer.
- box_layer: The napari box layer.
- i: Timeframe of the data.
Returns:
The state of this frame (either "division" or "track").