feat: Add visual similarity search and duplicate detection #686
base: main
Conversation
Walkthrough

This PR introduces an image similarity feature combining an ONNX Runtime-based MobileNetV2 embedding engine, SQLite-backed embedding storage, and three new API endpoints for background scanning, duplicate detection, and similarity search, with startup initialization in the main app.
Sequence Diagram(s)

sequenceDiagram
participant Client
participant API Routes
participant SimilarityEngine
participant FileSystem as File System
participant Database as SQLite DB
rect rgba(100, 200, 150, 0.2)
note over Client,Database: Background Embedding Generation (POST /scan)
Client->>API Routes: POST /scan
API Routes->>API Routes: trigger background task
API Routes-->>Client: 202 Accepted
API Routes->>FileSystem: iterate image files
loop for each image
API Routes->>SimilarityEngine: compute_embedding(image_path)
SimilarityEngine->>SimilarityEngine: load_model() (first run only)
SimilarityEngine->>SimilarityEngine: preprocess_image(image_path)
SimilarityEngine->>FileSystem: read image
SimilarityEngine->>SimilarityEngine: normalize embedding
API Routes->>Database: db_save_embedding(path, embedding)
end
end
rect rgba(150, 180, 220, 0.2)
note over Client,Database: Query Embeddings (GET /duplicates or GET /search)
Client->>API Routes: GET /duplicates or /search
API Routes->>Database: db_get_all_embeddings()
Database-->>API Routes: {path: embedding} dict
API Routes->>API Routes: compute cosine similarities (numpy ops)
API Routes->>API Routes: filter/rank by threshold or score
API Routes-->>Client: results (duplicates or ranked similar)
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25–30 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🧹 Nitpick comments (7)
backend/app/database/similarity.py (2)
31-44: Note the security implications of pickle serialization. Using pickle is acceptable for local-only embeddings generated on-device. However, be aware that pickle can execute arbitrary code during deserialization, so this approach would be unsafe if embeddings ever came from external/untrusted sources. Consider documenting this constraint or using a safer serialization format (e.g., numpy.save/numpy.frombuffer) if future requirements change:

# Alternative using numpy's native serialization:
from io import BytesIO

def db_save_embedding(image_path, embedding):
    conn = get_db_connection()
    try:
        # Store numpy array using numpy's native format
        buffer = BytesIO()
        np.save(buffer, embedding, allow_pickle=False)
        blob = buffer.getvalue()
        conn.execute('''
            INSERT OR REPLACE INTO image_embeddings (image_path, embedding)
            VALUES (?, ?)
        ''', (image_path, blob))
        conn.commit()
    except Exception as e:
        logger.error(f"Error saving embedding for {image_path}: {e}")
    finally:
        conn.close()
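If the numpy-based format is adopted, the read path needs a matching change. A minimal sketch of the corresponding loader, assuming the same table layout and connection helper (illustrative, not code from this PR):

```python
from io import BytesIO

import numpy as np


def db_get_all_embeddings():
    """Load all embeddings, deserializing each BLOB via numpy's native format."""
    conn = get_db_connection()
    try:
        rows = conn.execute(
            "SELECT image_path, embedding FROM image_embeddings"
        ).fetchall()
        return {
            path: np.load(BytesIO(blob), allow_pickle=False)
            for path, blob in rows
        }
    except Exception as e:
        logger.error(f"Error loading embeddings: {e}")
        return {}
    finally:
        conn.close()
```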
46-61: Consider scalability for large image collections. Loading all embeddings into memory works well for small to medium collections (< 10k images), but could cause memory issues with larger libraries. For now, this is acceptable given the local, on-device use case.
If scalability becomes a concern, consider:
- Adding pagination support
- Implementing batch retrieval
- Using a vector database (e.g., FAISS, Annoy) for efficient similarity search at scale (see the sketch below)
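As a concrete illustration of the last point, a minimal FAISS-based sketch (this assumes faiss-cpu were added as a dependency; it is not part of this PR):

```python
import faiss
import numpy as np


def build_index(embeddings: dict):
    """Build an exact inner-product index over L2-normalized embeddings."""
    paths = list(embeddings.keys())
    matrix = np.stack(list(embeddings.values())).astype(np.float32)
    index = faiss.IndexFlatIP(matrix.shape[1])  # inner product == cosine for unit vectors
    index.add(matrix)
    return paths, index


def query_index(index, paths, vec: np.ndarray, k: int = 10):
    """Return the k nearest paths with their cosine similarities."""
    scores, ids = index.search(vec.astype(np.float32).reshape(1, -1), k)
    return [(paths[i], float(s)) for i, s in zip(ids[0], scores[0]) if i != -1]
```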
backend/main.py (1)
52-64: Consider refactoring the startup sequence. The initialization logic is correct and the embeddings table is properly created. However, the startup function now has 8+ sequential table creation calls, which could benefit from refactoring.
Consider grouping related initializations:

def initialize_database_tables():
    """Initialize all database tables."""
    db_create_folders_table()
    db_create_images_table()
    db_create_YOLO_classes_table()
    db_create_clusters_table()
    db_create_faces_table()
    db_create_albums_table()
    db_create_album_images_table()
    db_create_metadata_table()
    db_create_embeddings_table()


@asynccontextmanager
async def lifespan(app: FastAPI):
    generate_openapi_json()
    initialize_database_tables()
    microservice_util_start_sync_service()
    # ... rest of the function

backend/app/utils/similarity.py (1)
36-56: LGTM! The image preprocessing follows standard MobileNetV2 requirements with correct ImageNet normalization values.
The comment on lines 43-44 could be more precise:

- # Normalize: (val - mean) / std. MobileNet expect [0,1] or standard normalization
- # MobileNetV2 generic expectation: scale to [0, 1], then normalize
+ # Normalize: scale to [0, 1], then apply ImageNet mean/std normalization
+ # This matches MobileNetV2's expected input preprocessing

backend/app/routes/similarity.py (3)
1-13: Remove unused import. Line 7 imports cdist from scipy.spatial.distance, but it's not used in the code (cosine similarity is computed with numpy.dot instead).

-from scipy.spatial.distance import cdist
15-35: Consider improving the background task robustness. The try/except ImportError pattern (lines 21-26) is unusual. If db_get_all_images should always be available, this suggests a structural issue. The function gracefully handles failures for individual images, which is good, but lacks progress reporting for long-running scans.
Consider:
- Removing the ImportError handling if db_get_all_images should always be available
- Adding progress logging (e.g., every 100 images)
- Adding a mechanism to track scan progress/status

 def process_all_images():
     """Background task to generate embeddings for all images."""
     logger.info("Starting background embedding generation...")
-    # NOTE: In a real integration, we'd fetch all images from the images table.
-    # For now, we assume we can iterate over a known list or scan.
-    # mocking db_get_all_images call for safety if not exists
-    try:
-        from app.database.images import db_get_all_images
-        images = db_get_all_images()  # returns list of dicts or objects
-    except ImportError:
-        logger.warning("Could not import db_get_all_images, skipping scan.")
-        return
+    from app.database.images import db_get_all_images
+    images = db_get_all_images()
     count = 0
+    total = len(images)
     for img in images:
         path = img['path']  # adjust key based on schema
         embedding = similarity_engine.compute_embedding(path)
         if embedding is not None:
             db_save_embedding(path, embedding)
             count += 1
+            if count % 100 == 0:
+                logger.info(f"Processed {count}/{total} images...")
     logger.info(f"Embedding generation complete. Processed {count} images.")
37-41: LGTM! The endpoint correctly schedules the background task and returns immediately.
Consider adding a status endpoint to check scan progress:

# Track scan status in app.state or a simple cache
@router.get("/scan/status")
async def get_scan_status():
    """Check if a scan is currently running."""
    # Implementation would require state tracking
    return {"running": False, "progress": 0, "total": 0}
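A minimal sketch of that state tracking, using a module-level dict that the background task updates (the scan_status name and fields are illustrative, not part of this PR, and only suit a single-process deployment):

```python
# Process-local scan state; sufficient for a single-worker local app.
scan_status = {"running": False, "progress": 0, "total": 0}


def process_all_images():
    """Background task that also reports progress via scan_status."""
    from app.database.images import db_get_all_images

    images = db_get_all_images()
    scan_status.update(running=True, progress=0, total=len(images))
    try:
        for img in images:
            embedding = similarity_engine.compute_embedding(img["path"])
            if embedding is not None:
                db_save_embedding(img["path"], embedding)
            scan_status["progress"] += 1
    finally:
        scan_status["running"] = False


@router.get("/scan/status")
async def get_scan_status():
    """Report whether a scan is running and how far along it is."""
    return scan_status
```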
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- backend/app/database/similarity.py (1 hunks)
- backend/app/routes/similarity.py (1 hunks)
- backend/app/utils/similarity.py (1 hunks)
- backend/main.py (4 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-10-31T17:00:50.132Z
Learnt from: Hemil36
Repo: AOSSIE-Org/PictoPy PR: 570
File: backend/app/database/connection.py:16-24
Timestamp: 2025-10-31T17:00:50.132Z
Learning: In PictoPy backend, the user prefers not to use database connection retry logic or extended busy timeouts in the centralized get_db_connection() context manager, even though the app has concurrent access patterns via ProcessPoolExecutor and FastAPI.
Applied to files:
backend/app/database/similarity.py
🧬 Code graph analysis (2)
backend/app/routes/similarity.py (3)
backend/app/utils/similarity.py (2)
- SimilarityEngine (10-85)
- compute_embedding (58-85)

backend/app/database/similarity.py (2)
- db_save_embedding (31-44)
- db_get_all_embeddings (46-61)

backend/app/database/images.py (1)
- db_get_all_images (123-214)
backend/main.py (2)
backend/app/database/similarity.py (1)
- db_create_embeddings_table (15-29)

backend/app/utils/microservice.py (1)
- microservice_util_start_sync_service (41-71)
🔇 Additional comments (6)
backend/app/database/similarity.py (2)
15-29: LGTM! The table schema is appropriate for the use case. The PRIMARY KEY on image_path provides an implicit index for fast lookups.
8-8: Verify DB_PATH consistency across database modules. Ensure that pictopy.db matches the database path used in other database modules in backend/app/database/ (e.g., images.py, faces.py). If other modules define different database paths, reconcile them to use a single configuration source.
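One way to reconcile them is a single shared constant that every database module imports. A minimal sketch (the module path and environment variable are assumptions, not existing PictoPy code):

```python
# backend/app/config/settings.py (hypothetical location)
import os

# Single source of truth for the SQLite database file; overridable for tests.
DATABASE_PATH = os.environ.get("PICTOPY_DB_PATH", "pictopy.db")
```

Modules such as images.py, faces.py, and similarity.py would then import DATABASE_PATH instead of hard-coding their own paths.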
backend/main.py (2)

24-25: LGTM! The imports are well-organized and follow the existing pattern with clear comments marking the new additions.
Also applies to: 34-35
147-149: LGTM! The similarity router is properly registered with appropriate prefix and tags.
backend/app/utils/similarity.py (2)
20-34: LGTM! The model loading logic is well-structured with proper error handling and uses the CPU execution provider for broad compatibility.
10-18: Add thread safety to singleton initialization. The singleton pattern has a race condition if multiple threads attempt to create the first instance simultaneously. However, the proposed double-checked locking pattern is error-prone in Python; a simpler approach using a single Lock is safer and more Pythonic:

+import threading
+
 class SimilarityEngine:
     _instance = None
+    _lock = threading.Lock()

     def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super(SimilarityEngine, cls).__new__(cls)
-            cls._instance.model_path = os.path.join(os.path.dirname(__file__), "mobilenetv2-7.onnx")
-            cls._instance.session = None
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super(SimilarityEngine, cls).__new__(cls)
+                cls._instance.model_path = os.path.join(os.path.dirname(__file__), "mobilenetv2-7.onnx")
+                cls._instance.session = None
         return cls._instance

Alternatively, if lazy initialization is not required, instantiate the singleton at module level for inherent thread safety.
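The module-level alternative is only a couple of lines; a sketch, assuming the class definition above:

```python
# Created once at import time; Python's per-module import lock makes this
# initialization inherently safe without explicit locking.
similarity_engine = SimilarityEngine()
```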
Also document where users should place the mobilenetv2-7.onnx model file or provide an automatic download mechanism.
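If automatic download is preferred, a small helper along these lines could run before the first inference. This is a sketch only; MODEL_URL is a placeholder and must be pointed at a trusted, verified source for mobilenetv2-7.onnx:

```python
import logging
import os
import urllib.request

logger = logging.getLogger(__name__)

# Placeholder -- replace with the project's verified download location.
MODEL_URL = "https://example.com/models/mobilenetv2-7.onnx"
MODEL_PATH = os.path.join(os.path.dirname(__file__), "mobilenetv2-7.onnx")


def ensure_model_downloaded(model_path: str = MODEL_PATH, url: str = MODEL_URL) -> bool:
    """Download the ONNX model on first use if it is not already on disk."""
    if os.path.exists(model_path):
        return True
    try:
        logger.info(f"Downloading MobileNetV2 model to {model_path}...")
        urllib.request.urlretrieve(url, model_path)
        return True
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        return False
```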
| @router.get("/duplicates") | ||
| async def find_duplicates(threshold: float = 0.95): | ||
| """Find duplicate images based on cosine similarity.""" | ||
| embeddings_map = db_get_all_embeddings() | ||
| if not embeddings_map: | ||
| return {"duplicates": []} | ||
|
|
||
| paths = list(embeddings_map.keys()) | ||
| matrix = np.stack(list(embeddings_map.values())) | ||
|
|
||
| # Compute potential duplicates | ||
| # This is O(N^2), feasible for small collections < 5k images locally. | ||
| # For optimization, FAISS or similar could be used. | ||
|
|
||
| duplicates = [] | ||
| visited = set() | ||
|
|
||
| # Dot product for normalized vectors is cosine similarity | ||
| sim_matrix = np.dot(matrix, matrix.T) | ||
|
|
||
| indices = np.where(sim_matrix > threshold) | ||
|
|
||
| # Group them | ||
| for i, j in zip(*indices): | ||
| if i != j and i not in visited and j not in visited: | ||
| # We found a pair (or more). Simple pair logic for now. | ||
| duplicates.append([paths[i], paths[j]]) | ||
| visited.add(i) | ||
| visited.add(j) | ||
|
|
||
| return {"duplicates": duplicates, "count": len(duplicates)} |
Fix duplicate detection logic to handle clusters correctly.
The current algorithm (lines 57-71) has a critical flaw: it only pairs duplicates and marks both as visited, which misses clusters of 3+ similar images.
Example failure case:
- Images A, B, C are all mutually similar (similarity > 0.95)
- Algorithm processes pair (A, B) → marks both visited
- When checking C against A or B, both are already visited → C is skipped
- Result: Only returns [A, B], missing C entirely
Replace the pairing logic with proper clustering:
+from collections import defaultdict
+
 @router.get("/duplicates")
 async def find_duplicates(threshold: float = 0.95):
     """Find duplicate images based on cosine similarity."""
     embeddings_map = db_get_all_embeddings()
     if not embeddings_map:
         return {"duplicates": []}
     paths = list(embeddings_map.keys())
     matrix = np.stack(list(embeddings_map.values()))
-    # Compute potential duplicates
-    # This is O(N^2), feasible for small collections < 5k images locally.
-    # For optimization, FAISS or similar could be used.
-
-    duplicates = []
-    visited = set()
-
     # Dot product for normalized vectors is cosine similarity
     sim_matrix = np.dot(matrix, matrix.T)
-    indices = np.where(sim_matrix > threshold)
+    # Build adjacency for connected components
+    parent = list(range(len(paths)))
-    # Group them
-    for i, j in zip(*indices):
-        if i != j and i not in visited and j not in visited:
-            # We found a pair (or more). Simple pair logic for now.
-            duplicates.append([paths[i], paths[j]])
-            visited.add(i)
-            visited.add(j)
-
-    return {"duplicates": duplicates, "count": len(duplicates)}
+    def find(x):
+        if parent[x] != x:
+            parent[x] = find(parent[x])
+        return parent[x]
+
+    def union(x, y):
+        px, py = find(x), find(y)
+        if px != py:
+            parent[px] = py
+
+    # Union similar images
+    indices = np.where(sim_matrix > threshold)
+    for i, j in zip(*indices):
+        if i != j:
+            union(i, j)
+
+    # Group by connected component
+    groups = defaultdict(list)
+    for i in range(len(paths)):
+        groups[find(i)].append(paths[i])
+
+    # Only return groups with 2+ images
+    duplicates = [group for group in groups.values() if len(group) > 1]
+
+    return {"duplicates": duplicates, "count": len(duplicates)}

@router.get("/search/{target_path:path}")
+ return {"duplicates": duplicates, "count": len(duplicates)}| @router.get("/search/{target_path:path}") | ||
| async def search_similar(target_path: str, limit: int = 10): | ||
| """Find images similar to the target path.""" | ||
| embeddings_map = db_get_all_embeddings() | ||
|
|
||
| if target_path not in embeddings_map: | ||
| # compute on the fly ? | ||
| vec = similarity_engine.compute_embedding(target_path) | ||
| if vec is None: | ||
| raise HTTPException(status_code=404, detail="Image not processed or invalid") | ||
| else: | ||
| vec = embeddings_map[target_path] | ||
|
|
||
| paths = list(embeddings_map.keys()) | ||
| matrix = np.stack(list(embeddings_map.values())) | ||
|
|
||
| scores = np.dot(matrix, vec) | ||
| top_indices = np.argsort(scores)[::-1][:limit] | ||
|
|
||
| results = [] | ||
| for idx in top_indices: | ||
| if paths[idx] != target_path: | ||
| results.append({"path": paths[idx], "score": float(scores[idx])}) | ||
|
|
||
| return results |
Save on-the-fly computed embeddings and fix result filtering.
The endpoint computes embeddings on-the-fly (lines 80-84) but doesn't save them, requiring recomputation on subsequent requests. Additionally, the filtering logic (line 96) can return fewer than limit results, because the target image is only excluded after the top-limit slice is taken.
Apply these improvements:
 if target_path not in embeddings_map:
-    # compute on the fly ?
+    # Compute and save on-the-fly
     vec = similarity_engine.compute_embedding(target_path)
     if vec is None:
-        raise HTTPException(status_code=404, detail="Image not processed or invalid")
+        raise HTTPException(status_code=404, detail="Image not processed or invalid")
+    # Save for future queries
+    db_save_embedding(target_path, vec)
+    embeddings_map[target_path] = vec
 else:
     vec = embeddings_map[target_path]

 paths = list(embeddings_map.keys())
 matrix = np.stack(list(embeddings_map.values()))
 scores = np.dot(matrix, vec)
 top_indices = np.argsort(scores)[::-1][:limit]
 results = []
 for idx in top_indices:
     if paths[idx] != target_path:
-        results.append({"path": paths[idx], "score": float(scores[idx])})
+        results.append({"path": paths[idx], "score": float(scores[idx])})
+        if len(results) >= limit:
+            break
 return results

def compute_embedding(self, image_path):
| """Computes the 1280-d embedding for an image.""" | ||
| if not self.load_model(): | ||
| return None | ||
|
|
||
| input_data = self.preprocess_image(image_path) | ||
| if input_data is None: | ||
| return None | ||
|
|
||
| try: | ||
| inputs = {self.session.get_inputs()[0].name: input_data} | ||
| outputs = self.session.run(None, inputs) | ||
| # MobileNetV2 (without classifier) output usually (1, 1280, 7, 7) or similar. | ||
| # We want a global pool. For generic MobileNetV2-7: output is 'output' | ||
| # If standard ImageNet model, last layer before classifier. | ||
| # Assuming the output is the feature vector or we pool it. | ||
|
|
||
| # MobileNetV2 output is (1, 1280, 7, 7). We perform Global Average Pooling. | ||
| output = outputs[0] | ||
| if len(output.shape) == 4: | ||
| output = np.mean(output, axis=(2, 3)) # (1, 1280) | ||
|
|
||
| embedding = output.flatten() | ||
|
|
||
| return embedding / np.linalg.norm(embedding) # Normalize for Cosine Sim | ||
| except Exception as e: | ||
| logger.error(f"Inference error for {image_path}: {e}") | ||
| return None |
Fix potential division by zero in L2 normalization.
Line 82 divides by the embedding's norm without checking for zero; an all-zero embedding (e.g., from a corrupted image or model issue) would produce NaN values that silently break downstream cosine similarities.
Apply this diff to add a safety check:
         embedding = output.flatten()

-        return embedding / np.linalg.norm(embedding)  # Normalize for Cosine Sim
+        norm = np.linalg.norm(embedding)
+        if norm == 0:
+            logger.error(f"Zero-norm embedding for {image_path}")
+            return None
+        return embedding / norm  # Normalize for Cosine Sim
     except Exception as e:
         logger.error(f"Inference error for {image_path}: {e}")
         return None
return None📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def compute_embedding(self, image_path):
    """Computes the 1280-d embedding for an image."""
    if not self.load_model():
        return None

    input_data = self.preprocess_image(image_path)
    if input_data is None:
        return None

    try:
        inputs = {self.session.get_inputs()[0].name: input_data}
        outputs = self.session.run(None, inputs)
        # MobileNetV2 (without classifier) output usually (1, 1280, 7, 7) or similar.
        # We want a global pool. For generic MobileNetV2-7: output is 'output'
        # If standard ImageNet model, last layer before classifier.
        # Assuming the output is the feature vector or we pool it.

        # MobileNetV2 output is (1, 1280, 7, 7). We perform Global Average Pooling.
        output = outputs[0]
        if len(output.shape) == 4:
            output = np.mean(output, axis=(2, 3))  # (1, 1280)

        embedding = output.flatten()

        norm = np.linalg.norm(embedding)
        if norm == 0:
            logger.error(f"Zero-norm embedding for {image_path}")
            return None
        return embedding / norm  # Normalize for Cosine Sim
    except Exception as e:
        logger.error(f"Inference error for {image_path}: {e}")
        return None
🤖 Prompt for AI Agents
In backend/app/utils/similarity.py around lines 58 to 85, the L2 normalization
on line 82 can divide by zero if the embedding norm is 0; add a safety check
after computing embedding: compute norm = np.linalg.norm(embedding) and if norm
== 0 (or norm < a small eps like 1e-12) handle it gracefully (e.g., log a
warning and return None or return the unnormalized embedding) otherwise return
embedding / norm; ensure the check prevents a ZeroDivisionError and keeps
behavior consistent with downstream cosine-similarity expectations.
Summary
This PR introduces Visual Similarity Search and Duplicate Image Detection, enabling users to find visually similar images and identify potential duplicates in their gallery. This feature powers a more intelligent and organized photo management experience while maintaining the project's privacy-first values.
Key Features
- Uses onnxruntime with a lightweight MobileNetV2 model to generate image embeddings locally. No data leaves the device.

Technical Changes
Backend
- Added an image_embeddings SQLite table and helper functions to store and retrieve binary feature vectors.
- POST /similarity/scan: Background task to index the entire library.
- GET /similarity/duplicates: Returns groups of similar images based on a cosine similarity threshold.
- GET /similarity/search/{path}: Returns nearest neighbors for a specific image.
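For reference, once the backend is running, the endpoints can be exercised like this (the base URL and image path are illustrative assumptions, and the requests library is not a project dependency):

```python
import requests

BASE = "http://localhost:8000/similarity"  # assumed local dev address

# Kick off background embedding generation (returns 202 immediately)
requests.post(f"{BASE}/scan")

# Groups of near-duplicate images above the cosine-similarity threshold
dupes = requests.get(f"{BASE}/duplicates", params={"threshold": 0.95}).json()
print(dupes.get("count", 0), "duplicate groups")

# Ten nearest neighbors of a specific image (path is illustrative)
similar = requests.get(f"{BASE}/search/photos/2024/IMG_0001.jpg", params={"limit": 10}).json()
for item in similar:
    print(item["path"], item["score"])
```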
Testing

Dependencies
Added onnxruntime, numpy, and scipy to requirements.txt.