micro_sam/sam_annotator/_state.py (12 changes: 9 additions & 3 deletions)
@@ -10,6 +10,7 @@
import zarr
import numpy as np

from napari.layers import Image
from qtpy.QtWidgets import QWidget

import torch.nn as nn
@@ -19,7 +20,6 @@
from micro_sam.instance_segmentation import AMGBase, get_decoder
from micro_sam.precompute_state import cache_amg_state, cache_is_state

from napari.layers import Image
from segment_anything import SamPredictor

try:
@@ -102,6 +102,7 @@ def initialize_predictor(
pbar_update=None,
skip_load=True,
use_cli=False,
is_sam2=False, # By default, we use SAM1.
):
assert ndim in (2, 3)

@@ -132,8 +133,13 @@ def progress_bar_factory(model_type):
self.image_embeddings = save_path
self.embedding_path = None # setting this to 'None' as we do not have embeddings cached.

else: # otherwise, compute the image embeddings.
self.image_embeddings = util.precompute_image_embeddings(
else: # Otherwise, compute the image embeddings.
if is_sam2:
from micro_sam2.util import precompute_image_embeddings as _comp_embed_fn
else:
_comp_embed_fn = util.precompute_image_embeddings

self.image_embeddings = _comp_embed_fn(
predictor=self.predictor,
input_=image_data,
save_path=save_path,
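
For context, here is a minimal usage sketch of the new `is_sam2` flag (not part of the diff). The flag only switches which `precompute_image_embeddings` implementation is called, so SAM2 embeddings come from `micro_sam2.util` while everything else follows the existing path. The `"hvit_b"` model type, the random test image, and the assumption that the remaining `initialize_predictor` arguments keep their defaults are illustrative.

```python
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam2.util import get_sam2_model  # companion package used by this PR

# Hypothetical example image; in the annotator this comes from the napari image layer.
image_data = np.random.randint(0, 255, (512, 512), dtype="uint8")

# Build a SAM2 image predictor, mirroring the embedding widget changes further down.
model = get_sam2_model(model_type="hvit_b", input_type="images")
predictor = SAM2ImagePredictor(model)

state = AnnotatorState()
state.initialize_predictor(
    image_data,
    model_type="hvit_b",   # assumed SAM2 naming convention ("hvit_" prefix)
    save_path=None,        # compute embeddings on the fly instead of loading a cache
    ndim=2,
    predictor=predictor,
    is_sam2=True,          # routes to micro_sam2.util.precompute_image_embeddings
)
```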
micro_sam/sam_annotator/_widgets.py (132 changes: 107 additions & 25 deletions)
@@ -31,6 +31,8 @@
# from napari.qt.threading import thread_worker
from napari.utils import progress

from segment_anything import SamPredictor

from . import util as vutil
from ._tooltips import get_tooltip
from ._state import AnnotatorState
@@ -237,8 +239,11 @@ def _get_model_size_options(self):
# We store the actual model names mapped to UI labels.
self.model_size_mapping = {}
if self.model_family == "Natural Images (SAM)":
self.model_size_options = list(self._model_size_map .values())
self.model_size_options = list(self._model_size_map.values())
self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
elif self.model_family == "Natural Images (SAM2)":
self.model_size_options = list(self._model_size_map.values())
self.model_size_mapping = {self._model_size_map[k]: f"hvit_{k}" for k in self._model_size_map.keys()}
else:
model_suffix = self.supported_dropdown_maps[self.model_family]
self.model_size_options = []
@@ -278,7 +283,10 @@ def _update_model_type(self):
size_key = next(
(k for k, v in self._model_size_map.items() if v == self.model_size), "b"
)
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
if "SAM2" in self.model_family:
self.model_type = f"hvit_{size_key}"
else:
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]

self.model_size_dropdown.setCurrentText(self.model_size) # Apply the selected text to the dropdown

@@ -293,6 +301,7 @@ def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create
# Create a list of supported dropdown values and map them to suffixes.
self.supported_dropdown_maps = {
"Natural Images (SAM)": "",
"Natural Images (SAM2)": "_sam2",
"Light Microscopy": "_lm",
"Electron Microscopy": "_em_organelles",
"Medical Imaging": "_medical_imaging",
@@ -343,7 +352,10 @@ def _create_model_size_section(self):

def _validate_model_type_and_custom_weights(self):
# Let's get all model combination stuff into the desired `model_type` structure.
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
if "SAM2" in self.model_family:
self.model_type = "hvit_" + self.model_size[0]
else:
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]

# For 'custom_weights', we remove the displayed text on top of the drop-down menu.
if self.custom_weights:
@@ -1014,10 +1026,21 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:

predictor = AnnotatorState().predictor
image_embeddings = AnnotatorState().image_embeddings
seg = vutil.prompt_segmentation(
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
)

if isinstance(predictor, SamPredictor): # This is SAM1 predictor.
seg = vutil.prompt_segmentation(
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
)
else: # This would be SAM2 predictors.
from micro_sam2.prompt_based_segmentation import promptable_segmentation_2d
seg = promptable_segmentation_2d(
predictor=predictor,
points=points,
labels=labels,
boxes=boxes,
masks=masks,
)

# no prompts were given or prompts were invalid, skip segmentation
if seg is None:
@@ -1451,11 +1474,35 @@ def pbar_init(total, description):
if self.automatic_segmentation_mode == "amg":
prefer_decoder = False

# Define a predictor for SAM2 models.
predictor = None
if self.model_type.startswith("h"): # i.e. SAM2 models.
from micro_sam2.util import get_sam2_model

if ndim == 2: # Get the SAM2 model and prepare the image predictor.
model = get_sam2_model(model_type=self.model_type, input_type="images")
# Prepare the SAM2 predictor.
from sam2.sam2_image_predictor import SAM2ImagePredictor
predictor = SAM2ImagePredictor(model)
elif ndim == 3: # Get SAM2 video predictor
predictor = get_sam2_model(model_type=self.model_type, input_type="videos")
else:
raise ValueError

state.initialize_predictor(
image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
prefer_decoder=prefer_decoder, pbar_init=pbar_init,
image_data,
model_type=self.model_type,
save_path=save_path,
ndim=ndim,
device=self.device,
checkpoint_path=self.custom_weights,
predictor=predictor,
tile_shape=tile_shape,
halo=halo,
prefer_decoder=prefer_decoder,
pbar_init=pbar_init,
pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
is_sam2=self.model_type.startswith("h"),
)
pbar_signals.pbar_stop.emit()
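
Condensed into a standalone sketch (not part of the PR), the predictor selection above works as follows: SAM2 models are recognized by the leading "h" of `model_type`, 2D data gets a `SAM2ImagePredictor`, 3D data gets the SAM2 video predictor, and SAM1 models keep `predictor=None` so that `initialize_predictor` falls back to the existing SAM1 path. The helper name is hypothetical.

```python
from typing import Optional


def _pick_sam2_predictor(model_type: str, ndim: int) -> Optional[object]:
    """Return a SAM2 predictor for 'hvit_*' model types, or None for SAM1 models."""
    if not model_type.startswith("h"):  # "vit_*" models are handled by initialize_predictor itself.
        return None

    from micro_sam2.util import get_sam2_model  # companion package used by this PR

    if ndim == 2:
        # 2D images use the SAM2 image predictor.
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        return SAM2ImagePredictor(get_sam2_model(model_type=model_type, input_type="images"))
    elif ndim == 3:
        # Volumes are treated as videos, so the SAM2 video predictor is used directly.
        return get_sam2_model(model_type=model_type, input_type="videos")
    raise ValueError(f"Invalid number of dimensions: {ndim}. Expected 2 or 3.")
```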

@@ -1613,24 +1660,59 @@ def volumetric_segmentation_impl():
pbar_signals.pbar_total.emit(shape[0])
pbar_signals.pbar_description.emit("Segment object")

# Step 1: Segment all slices with prompts.
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
state.image_embeddings, shape,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)
if isinstance(state.predictor, SamPredictor): # This is SAM1 predictor.
# Step 1: Segment all slices with prompts.
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
state.image_embeddings, shape,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)

# Step 2: Segment the rest of the volume based on projecting prompts.
seg, (z_min, z_max) = segment_mask_in_volume(
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=self.iou_threshold, projection=self.projection,
box_extension=self.box_extension,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)

state.z_range = (z_min, z_max)

else: # This would be SAM2 predictors.
# Prepare the prompts
point_prompts = self._viewer.layers["point_prompts"]
box_prompts = self._viewer.layers["prompts"]
z_values = np.round(point_prompts.data[:, 0])
z_values_boxes = np.concatenate(
[box[:1, 0] for box in box_prompts.data]
) if box_prompts.data else np.zeros(0, dtype="int")

# TODO: Make this modular.
if len(z_values) > 0:  # Note: truth-testing the numpy array directly is ambiguous for multiple points.
frame_id = z_values[0]
else:
frame_id = z_values_boxes[0]

# Get the volume
# TODO: We need to switch later to volume embeddings.
volume = self._viewer.layers[0].data # Assumes the image is the first layer in the viewer.

points, labels = vutil.point_layer_to_prompts(point_prompts, frame_id)
boxes, masks = vutil.shape_layer_to_prompts(box_prompts, state.image_shape, i=frame_id)

from micro_sam2.prompt_based_segmentation import promptable_segmentation_3d
seg = promptable_segmentation_3d(
predictor=state.predictor,
volume=volume,
frame_id=frame_id,
points=points,
labels=labels,
boxes=boxes,
)

# Step 2: Segment the rest of the volume based on projecting prompts.
seg, (z_min, z_max) = segment_mask_in_volume(
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=self.iou_threshold, projection=self.projection,
box_extension=self.box_extension,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)
pbar_signals.pbar_stop.emit()

state.z_range = (z_min, z_max)
return seg

def update_segmentation(seg):
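
To summarize the conventions introduced across these widget changes, here is an illustrative sketch (not part of the PR) of how the UI selections resolve to a `model_type` string and how SAM1 is told apart from SAM2 at runtime. The suffix map mirrors `supported_dropdown_maps`, and the `"hvit_"` prefix for SAM2 is the convention assumed throughout this diff; the helper names are hypothetical.

```python
from segment_anything import SamPredictor


def resolve_model_type(model_family: str, size_key: str) -> str:
    # Mirrors supported_dropdown_maps plus the SAM2 special case in _validate_model_type_and_custom_weights.
    suffixes = {
        "Natural Images (SAM)": "",
        "Light Microscopy": "_lm",
        "Electron Microscopy": "_em_organelles",
        "Medical Imaging": "_medical_imaging",
    }
    if "SAM2" in model_family:
        return f"hvit_{size_key}"  # e.g. "hvit_b"; SAM2 models do not get a domain suffix.
    return f"vit_{size_key}" + suffixes[model_family]  # e.g. "vit_l_lm"


def uses_sam2(predictor) -> bool:
    # SAM1 predictors are instances of segment_anything.SamPredictor; everything else
    # (image or video predictors from the sam2 package) is treated as SAM2.
    return not isinstance(predictor, SamPredictor)


assert resolve_model_type("Natural Images (SAM2)", "b") == "hvit_b"
assert resolve_model_type("Light Microscopy", "l") == "vit_l_lm"
```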
micro_sam/sam_annotator/annotator_3d.py (10 changes: 10 additions & 0 deletions)
@@ -46,6 +46,16 @@ def _update_image(self, segmentation_result=None):
else:
state.amg_state = _load_amg_state(state.embedding_path)

# NOTE: this is an intermediate solution. We should re-design the plugin for SAM2 eventually.
# If we have a SAM2 model, then remove the widget for segmenting single slices.
model_type = self._embedding_widget.model_type
if model_type.startswith("hvit") and "segment" in self._widgets:
widget = self._widgets.pop("segment")
layout = self._annotator_widget.layout()
layout.removeWidget(widget.native)
widget.native.setParent(None)
widget.native.deleteLater()


def annotator_3d(
image: np.ndarray,
micro_sam/sam_annotator/util.py (3 changes: 2 additions & 1 deletion)
@@ -693,7 +693,8 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic

# Update the index for model size, e.g. 'base', 'tiny', etc.
size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
model_size = size_map[model_type[4]]
size_idx = 5 if model_type.startswith("h") else 4
model_size = size_map[model_type[size_idx]]

index = widget.model_size_dropdown.findText(model_size)
if index > 0:
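
The size lookup in `_sync_embedding_widget` now depends on where the size character sits in the `model_type` string. A small illustrative sketch of the index arithmetic, assuming the naming conventions from this PR:

```python
size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}

for model_type in ["vit_b", "vit_l_lm", "hvit_t"]:
    # SAM2 model types ("hvit_*") carry an extra leading "h", so the size character
    # sits at index 5 instead of index 4 (as in "vit_*").
    size_idx = 5 if model_type.startswith("h") else 4
    print(model_type, "->", size_map[model_type[size_idx]])

# Prints:
# vit_b -> base
# vit_l_lm -> large
# hvit_t -> tiny
```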