diff --git a/micro_sam/util.py b/micro_sam/util.py index a817ca785..992122ddb 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -593,25 +593,39 @@ def get_model_names() -> Iterable: # -def _to_image(input_): - # we require the input to be uint8 - if input_.dtype != np.dtype("uint8"): - # first normalize the input to [0, 1] - input_ = input_.astype("float32") - input_.min() - input_ = input_ / input_.max() - # then bring to [0, 255] and cast to uint8 - input_ = (input_ * 255).astype("uint8") - - if input_.ndim == 2: - image = np.concatenate([input_[..., None]] * 3, axis=-1) - elif input_.ndim == 3 and input_.shape[-1] == 3: - image = input_ - else: - raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.") +def _normalize_channel(input_, min_val=None, max_val=None): + # First normalize the input to [0, 1]. + input_ = input_.astype("float32") + min_val = np.percentile(input_, 1) if min_val is None else min_val + input_ = input_ - min_val + max_val = np.percentile(input_, 99) if max_val is None else max_val + input_ = input_ / (max_val + 1e-7) + # Then bring it to [0, 255] and cast to uint8. + input_ = (np.clip(input_, 0, 1) * 255).astype("uint8") + return input_ + + +def _to_image(input_, min_=None, max_=None): + # Explicitly return a numpy array for compatibility with torchvision, + # because the input_ array could be something like dask array. + image = np.array(input_) + + if image.ndim == 2: + image = image[..., None] + elif image.ndim != 3 or (image.shape[-1] not in (1, 3)): + raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either grayscale or RGB image.") + + image_normalized = np.zeros(image.shape, dtype="uint8") + for c in range(image.shape[2]): + min_val = None if min_ is None else min_[c] + max_val = None if max_ is None else max_[c] + image_normalized[..., c] = _normalize_channel(image[..., c], min_val=min_val, max_val=max_val) - # explicitly return a numpy array for compatibility with torchvision - # because the input_ array could be something like dask array - return np.array(image) + if image_normalized.shape[-1] == 1: + image_normalized = np.concatenate([image_normalized] * 3, axis=-1) + assert image_normalized.shape[2] == 3, f"{image_normalized.shape}" + + return image_normalized @torch.no_grad @@ -690,6 +704,7 @@ def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init pbar_init(n_tiles, "Compute Image Embeddings 2D tiled") n_batches = int(np.ceil(n_tiles / batch_size)) + input_ = _to_image(input_) for batch_id in range(n_batches): tile_start = batch_id * batch_size tile_stop = min(tile_start + batch_size, n_tiles) @@ -698,7 +713,7 @@ def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init for tile_id in range(tile_start, tile_stop): tile = tiling.getBlockWithHalo(tile_id, list(halo)) outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) - tile_input = _to_image(input_[outer_tile]) + tile_input = input_[outer_tile] batched_images.append(tile_input) batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) @@ -715,6 +730,18 @@ def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init return features +def _precompute_stats(input_): + stats = {} + input_ = input_[..., None] if input_.ndim == 3 else input_ + assert input_.ndim == 4 + for z in range(input_.shape[0]): + min_ = {c: np.percentile(input_[z, ..., c], 1) for c in range(input_.shape[3])} + # TODO double check + max_ = {c: np.percentile(input_[z, ..., c], 99) - min_[c] for c in range(input_.shape[3])} + stats[z] = {"min": min_, "max": max_} + return stats + + def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size): assert input_.ndim == 3 @@ -733,6 +760,9 @@ def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init # We batch across the z axis. n_batches = int(np.ceil(n_slices / batch_size)) + # Precompute min and max for each slice. + stats = _precompute_stats(input_) + for tile_id in range(n_tiles): tile = tiling.getBlockWithHalo(tile_id, list(halo)) outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) @@ -744,7 +774,7 @@ def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init batched_images = [] for z in range(z_start, z_stop): - tile_input = _to_image(input_[z][outer_tile]) + tile_input = _to_image(input_[z][outer_tile], min_=stats[z]["min"], max_=stats[z]["max"]) batched_images.append(tile_input) batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) @@ -858,8 +888,8 @@ def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_u # Skip feature computation in case of partial features in non-zero slice. if partial_features and np.count_nonzero(features[z]) != 0: continue - tile_input = _to_image(input_[z]) - batched_images.append(tile_input) + batch_input = _to_image(input_[z]) + batched_images.append(batch_input) batched_z.append(z) batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)