diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py
index c15a1fcf..2d3ca86c 100644
--- a/micro_sam/sam_annotator/_state.py
+++ b/micro_sam/sam_annotator/_state.py
@@ -10,6 +10,7 @@
 import zarr
 import numpy as np
 
+from napari.layers import Image
 from qtpy.QtWidgets import QWidget
 
 import torch.nn as nn
@@ -19,7 +20,6 @@
 from micro_sam.instance_segmentation import AMGBase, get_decoder
 from micro_sam.precompute_state import cache_amg_state, cache_is_state
 
-from napari.layers import Image
 from segment_anything import SamPredictor
 
 try:
@@ -102,6 +102,7 @@ def initialize_predictor(
         pbar_update=None,
         skip_load=True,
         use_cli=False,
+        is_sam2=False,  # By default, we use SAM1.
     ):
         assert ndim in (2, 3)
@@ -132,8 +133,13 @@ def progress_bar_factory(model_type):
             self.image_embeddings = save_path
             self.embedding_path = None  # setting this to 'None' as we do not have embeddings cached.
-        else:  # otherwise, compute the image embeddings.
-            self.image_embeddings = util.precompute_image_embeddings(
+        else:  # Otherwise, compute the image embeddings.
+            if is_sam2:
+                from micro_sam2.util import precompute_image_embeddings as _comp_embed_fn
+            else:
+                _comp_embed_fn = util.precompute_image_embeddings
+
+            self.image_embeddings = _comp_embed_fn(
                 predictor=self.predictor,
                 input_=image_data,
                 save_path=save_path,
diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py
index 34d8592a..ba08f30e 100644
--- a/micro_sam/sam_annotator/_widgets.py
+++ b/micro_sam/sam_annotator/_widgets.py
@@ -31,6 +31,8 @@
 # from napari.qt.threading import thread_worker
 from napari.utils import progress
 
+from segment_anything import SamPredictor
+
 from . import util as vutil
 from ._tooltips import get_tooltip
 from ._state import AnnotatorState
@@ -237,8 +239,11 @@ def _get_model_size_options(self):
         # We store the actual model names mapped to UI labels.
         self.model_size_mapping = {}
         if self.model_family == "Natural Images (SAM)":
-            self.model_size_options = list(self._model_size_map .values())
+            self.model_size_options = list(self._model_size_map.values())
             self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
+        elif self.model_family == "Natural Images (SAM2)":
+            self.model_size_options = list(self._model_size_map.values())
+            self.model_size_mapping = {self._model_size_map[k]: f"hvit_{k}" for k in self._model_size_map.keys()}
         else:
             model_suffix = self.supported_dropdown_maps[self.model_family]
             self.model_size_options = []
@@ -278,7 +283,10 @@ def _update_model_type(self):
         size_key = next(
             (k for k, v in self._model_size_map.items() if v == self.model_size), "b"
         )
-        self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
+        if "SAM2" in self.model_family:
+            self.model_type = f"hvit_{size_key}"
+        else:
+            self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
 
         self.model_size_dropdown.setCurrentText(self.model_size)  # Apply the selected text to the dropdown
@@ -293,6 +301,7 @@ def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create
         # Create a list of support dropdown values and correspond them to suffixes.
         self.supported_dropdown_maps = {
             "Natural Images (SAM)": "",
+            "Natural Images (SAM2)": "_sam2",
             "Light Microscopy": "_lm",
             "Electron Microscopy": "_em_organelles",
             "Medical Imaging": "_medical_imaging",
@@ -343,7 +352,10 @@ def _create_model_size_section(self):
 
     def _validate_model_type_and_custom_weights(self):
         # Let's get all model combination stuff into the desired `model_type` structure.
-        self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
+        if "SAM2" in self.model_family:
+            self.model_type = "hvit_" + self.model_size[0]
+        else:
+            self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
 
         # For 'custom_weights', we remove the displayed text on top of the drop-down menu.
         if self.custom_weights:
@@ -1014,10 +1026,21 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:
     predictor = AnnotatorState().predictor
     image_embeddings = AnnotatorState().image_embeddings
-    seg = vutil.prompt_segmentation(
-        predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
-        multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
-    )
+
+    if isinstance(predictor, SamPredictor):  # This is the SAM1 predictor.
+        seg = vutil.prompt_segmentation(
+            predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
+            multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
+        )
+    else:  # This is a SAM2 predictor.
+        from micro_sam2.prompt_based_segmentation import promptable_segmentation_2d
+        seg = promptable_segmentation_2d(
+            predictor=predictor,
+            points=points,
+            labels=labels,
+            boxes=boxes,
+            masks=masks,
+        )
 
     # no prompts were given or prompts were invalid, skip segmentation
     if seg is None:
@@ -1451,11 +1474,35 @@ def pbar_init(total, description):
 
             if self.automatic_segmentation_mode == "amg":
                 prefer_decoder = False
 
+            # Define a predictor for SAM2 models.
+            predictor = None
+            if self.model_type.startswith("h"):  # i.e. SAM2 models.
+                from micro_sam2.util import get_sam2_model
+
+                if ndim == 2:  # Get the SAM2 model and prepare the image predictor.
+                    model = get_sam2_model(model_type=self.model_type, input_type="images")
+                    # Prepare the SAM2 predictor.
+                    from sam2.sam2_image_predictor import SAM2ImagePredictor
+                    predictor = SAM2ImagePredictor(model)
+                elif ndim == 3:  # Get the SAM2 video predictor.
+                    predictor = get_sam2_model(model_type=self.model_type, input_type="videos")
+                else:
+                    raise ValueError(f"Expected 2d or 3d image data, got ndim={ndim}.")
+
             state.initialize_predictor(
-                image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
-                device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
-                prefer_decoder=prefer_decoder, pbar_init=pbar_init,
+                image_data,
+                model_type=self.model_type,
+                save_path=save_path,
+                ndim=ndim,
+                device=self.device,
+                checkpoint_path=self.custom_weights,
+                predictor=predictor,
+                tile_shape=tile_shape,
+                halo=halo,
+                prefer_decoder=prefer_decoder,
+                pbar_init=pbar_init,
                 pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
+                is_sam2=self.model_type.startswith("h"),
             )
             pbar_signals.pbar_stop.emit()
@@ -1613,24 +1660,59 @@ def volumetric_segmentation_impl():
             pbar_signals.pbar_total.emit(shape[0])
             pbar_signals.pbar_description.emit("Segment object")
 
-            # Step 1: Segment all slices with prompts.
-            seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
-                state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
-                state.image_embeddings, shape,
-                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
-            )
+            if isinstance(state.predictor, SamPredictor):  # This is the SAM1 predictor.
+                # Step 1: Segment all slices with prompts.
+                seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
+                    state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
+                    state.image_embeddings, shape,
+                    update_progress=lambda update: pbar_signals.pbar_update.emit(update),
+                )
+
+                # Step 2: Segment the rest of the volume based on projecting prompts.
+                seg, (z_min, z_max) = segment_mask_in_volume(
+                    seg, state.predictor, state.image_embeddings, slices,
+                    stop_lower, stop_upper,
+                    iou_threshold=self.iou_threshold, projection=self.projection,
+                    box_extension=self.box_extension,
+                    update_progress=lambda update: pbar_signals.pbar_update.emit(update),
+                )
+
+                state.z_range = (z_min, z_max)
+
+            else:  # This is a SAM2 predictor.
+                # Prepare the prompts.
+                point_prompts = self._viewer.layers["point_prompts"]
+                box_prompts = self._viewer.layers["prompts"]
+                z_values = np.round(point_prompts.data[:, 0])
+                z_values_boxes = np.concatenate(
+                    [box[:1, 0] for box in box_prompts.data]
+                ) if box_prompts.data else np.zeros(0, dtype="int")
+
+                # TODO: Make this modular.
+                if len(z_values) > 0:
+                    frame_id = z_values[0]
+                else:
+                    frame_id = z_values_boxes[0]
+
+                # Get the volume.
+                # TODO: We need to switch later to volume embeddings.
+                volume = self._viewer.layers[0].data  # Assumption: the image is the first layer.
+
+                points, labels = vutil.point_layer_to_prompts(point_prompts, frame_id)
+                boxes, masks = vutil.shape_layer_to_prompts(box_prompts, state.image_shape, i=frame_id)
+
+                from micro_sam2.prompt_based_segmentation import promptable_segmentation_3d
+                seg = promptable_segmentation_3d(
+                    predictor=state.predictor,
+                    volume=volume,
+                    frame_id=frame_id,
+                    points=points,
+                    labels=labels,
+                    boxes=boxes,
+                )
-            # Step 2: Segment the rest of the volume based on projecting prompts.
-            seg, (z_min, z_max) = segment_mask_in_volume(
-                seg, state.predictor, state.image_embeddings, slices,
-                stop_lower, stop_upper,
-                iou_threshold=self.iou_threshold, projection=self.projection,
-                box_extension=self.box_extension,
-                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
-            )
 
             pbar_signals.pbar_stop.emit()
-            state.z_range = (z_min, z_max)
 
             return seg
 
         def update_segmentation(seg):
diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py
index ebc52592..6adef6ae 100644
--- a/micro_sam/sam_annotator/annotator_3d.py
+++ b/micro_sam/sam_annotator/annotator_3d.py
@@ -46,6 +46,16 @@ def _update_image(self, segmentation_result=None):
         else:
             state.amg_state = _load_amg_state(state.embedding_path)
 
+        # NOTE: This is an intermediate solution. We should re-design the plugin for SAM2 eventually.
+        # If we have a SAM2 model, then remove the widget for segmenting single slices.
+        model_type = self._embedding_widget.model_type
+        if model_type.startswith("hvit") and "segment" in self._widgets:
+            widget = self._widgets.pop("segment")
+            layout = self._annotator_widget.layout()
+            layout.removeWidget(widget.native)
+            widget.native.setParent(None)
+            widget.native.deleteLater()
+
 
 def annotator_3d(
     image: np.ndarray,
diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py
index 85096fb1..7c1bea4c 100644
--- a/micro_sam/sam_annotator/util.py
+++ b/micro_sam/sam_annotator/util.py
@@ -693,7 +693,8 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic
 
     # Update the index for model size, eg. 'base', 'tiny', etc.
     size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
-    model_size = size_map[model_type[4]]
+    size_idx = 5 if model_type.startswith("h") else 4
+    model_size = size_map[model_type[size_idx]]
     index = widget.model_size_dropdown.findText(model_size)
     if index > 0:
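
A minimal usage sketch of the new SAM2 code path (illustration only, not part of the diff; `get_sam2_model`, the `hvit_` naming, the `predictor` keyword, and the `is_sam2` flag are taken from the changes above, and the exact `micro_sam2` signatures are assumptions):

    import numpy as np

    from sam2.sam2_image_predictor import SAM2ImagePredictor

    from micro_sam.sam_annotator._state import AnnotatorState
    from micro_sam2.util import get_sam2_model  # Assumed helper, as used in the diff above.

    # SAM2 model types use the "hvit_" prefix (e.g. "hvit_b"); SAM1 models keep "vit_".
    model_type = "hvit_b"
    image_data = np.random.randint(0, 255, (512, 512), dtype="uint8")  # Placeholder 2d image.

    # Build the SAM2 image predictor for 2d data, mirroring what the embedding widget does.
    model = get_sam2_model(model_type=model_type, input_type="images")
    predictor = SAM2ImagePredictor(model)

    # Passing 'is_sam2=True' routes embedding computation to micro_sam2.util.precompute_image_embeddings
    # instead of micro_sam.util.precompute_image_embeddings.
    state = AnnotatorState()
    state.initialize_predictor(image_data, model_type=model_type, ndim=2, predictor=predictor, is_sam2=True)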