@@ -1,8 +1,8 @@
"""Initial migration
"""initial_schema_with_sub

-Revision ID: 0f7a55085f49
+Revision ID: 3e1f02d20edc
Revises:
-Create Date: 2025-12-11 20:58:00.476719
+Create Date: 2025-12-21 01:06:34.911124

"""
from typing import Sequence, Union
@@ -12,7 +12,7 @@
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
-revision: str = '0f7a55085f49'
+revision: str = '3e1f02d20edc'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -23,6 +23,7 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('conversations',
sa.Column('id', sa.Integer(), nullable=False),
+sa.Column('user_id', sa.String(), nullable=False),
sa.Column('session_id', sa.String(), nullable=False),
sa.Column('title', sa.String(), nullable=True),
sa.Column('dataset_ids', sa.String(), nullable=True),
@@ -32,38 +33,42 @@ def upgrade() -> None:
)
op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
op.create_index(op.f('ix_conversations_session_id'), 'conversations', ['session_id'], unique=True)
+op.create_index(op.f('ix_conversations_user_id'), 'conversations', ['user_id'], unique=False)
op.create_table('datasets',
sa.Column('id', sa.Integer(), nullable=False),
+sa.Column('user_id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
+sa.Column('shared', sa.Boolean(), nullable=True),
sa.Column('annotation', sa.Text(), nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_datasets_id'), 'datasets', ['id'], unique=False)
op.create_index(op.f('ix_datasets_name'), 'datasets', ['name'], unique=True)
+op.create_index(op.f('ix_datasets_user_id'), 'datasets', ['user_id'], unique=False)
op.create_table('files',
sa.Column('id', sa.Integer(), nullable=False),
+sa.Column('user_id', sa.String(), nullable=False),
sa.Column('file_path', sa.String(), nullable=False),
+sa.Column('shared', sa.Boolean(), nullable=True),
sa.Column('upload_date', postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('file_path')
)
op.create_index(op.f('ix_files_id'), 'files', ['id'], unique=False)
+op.create_index(op.f('ix_files_user_id'), 'files', ['user_id'], unique=False)
op.create_table('user_settings',
-sa.Column('id', sa.Integer(), nullable=False),
-sa.Column('user_hash', sa.String(), nullable=False),
-sa.Column('email', sa.String(), nullable=True),
+sa.Column('user_id', sa.String(), nullable=False),
sa.Column('openai_api_key', sa.String(), nullable=True),
sa.Column('anthropic_api_key', sa.String(), nullable=True),
sa.Column('gemini_api_key', sa.String(), nullable=True),
sa.Column('together_api_key', sa.String(), nullable=True),
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=True),
-sa.PrimaryKeyConstraint('id')
+sa.PrimaryKeyConstraint('user_id')
)
-op.create_index(op.f('ix_user_settings_id'), 'user_settings', ['id'], unique=False)
-op.create_index(op.f('ix_user_settings_user_hash'), 'user_settings', ['user_hash'], unique=True)
+op.create_index(op.f('ix_user_settings_user_id'), 'user_settings', ['user_id'], unique=True)
op.create_table('dataset_files',
sa.Column('dataset_id', sa.Integer(), nullable=False),
sa.Column('file_id', sa.Integer(), nullable=False),
@@ -92,14 +97,16 @@ def downgrade() -> None:
op.drop_index(op.f('ix_messages_id'), table_name='messages')
op.drop_table('messages')
op.drop_table('dataset_files')
-op.drop_index(op.f('ix_user_settings_user_hash'), table_name='user_settings')
-op.drop_index(op.f('ix_user_settings_id'), table_name='user_settings')
+op.drop_index(op.f('ix_user_settings_user_id'), table_name='user_settings')
op.drop_table('user_settings')
+op.drop_index(op.f('ix_files_user_id'), table_name='files')
op.drop_index(op.f('ix_files_id'), table_name='files')
op.drop_table('files')
+op.drop_index(op.f('ix_datasets_user_id'), table_name='datasets')
op.drop_index(op.f('ix_datasets_name'), table_name='datasets')
op.drop_index(op.f('ix_datasets_id'), table_name='datasets')
op.drop_table('datasets')
+op.drop_index(op.f('ix_conversations_user_id'), table_name='conversations')
op.drop_index(op.f('ix_conversations_session_id'), table_name='conversations')
op.drop_index(op.f('ix_conversations_id'), table_name='conversations')
op.drop_table('conversations')
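Note that the regenerated revision keeps down_revision = None but changes the revision ID, so a database already stamped with 0f7a55085f49 no longer matches the migration history. A minimal sketch of handling that through Alembic's Python API, assuming a standard alembic.ini at the backend root:

# Sketch only: re-point an existing database at the regenerated revision.
# Assumes alembic.ini sits next to the alembic/ directory; adjust the path.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")

# Fresh database: just apply the new initial schema.
command.upgrade(cfg, "head")

# Database created under 0f7a55085f49: stamp rewrites alembic_version to
# 3e1f02d20edc without running any DDL, so the new user_id/shared columns
# would still need to be added manually (or the database recreated).
# command.stamp(cfg, "3e1f02d20edc")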
app/backend/app/auth.py (17 changes: 6 additions & 11 deletions)
@@ -26,7 +26,7 @@ def get_jwks():
raise HTTPException(status_code=500, detail="Could not verify authentication keys") from e
return jwks_cache

-def get_user_hash(authorization: str | None = Header(None)) -> tuple[str, str]:
+def get_current_user(authorization: str | None = Header(None)) -> str:
"""
Validates the Bearer token signature using Auth0 JWKS and returns the Auth0 user ID ('sub' claim).
"""
@@ -69,17 +69,12 @@ def get_user_hash(authorization: str | None = Header(None)) -> tuple[str, str]:
issuer=AUTH0_ISSUER,
)

-# extract email and compute salted user hash
-email = payload.get(f"{CLAIMS_NAMESPACE}/email")
-if not email:
-    # fallback: Sometimes standard 'email' claim exists if configured differently
-    email = payload.get("email")
-if not email:
-    raise HTTPException(status_code=400, detail="Token missing email claim")
-hash_input = f"{email}{SALT}"
-user_hash = hashlib.sha256(hash_input.encode('utf-8')).hexdigest()
+# extract user_id
+user_id = payload.get("sub")
+if not user_id:
+    raise HTTPException(status_code=401, detail="Token missing 'sub' claim")

-return user_hash, email
+return user_id

except jwt.ExpiredSignatureError as e:
raise HTTPException(status_code=401, detail="Token is expired") from e
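The renamed dependency now returns a single identifier rather than a (user_hash, email) tuple, so route handlers can take it directly. A usage sketch, assuming the module imports as app.auth; the route and handler are illustrative, not part of this PR:

# Sketch only: consuming get_current_user as a FastAPI dependency.
from fastapi import Depends, FastAPI

from app.auth import get_current_user  # assumed import path

app = FastAPI()

@app.get("/me")
async def read_me(user_id: str = Depends(get_current_user)) -> dict[str, str]:
    # user_id is the raw Auth0 'sub' claim, e.g. "auth0|abc123"
    return {"user_id": user_id}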
app/backend/app/database.py (24 changes: 12 additions & 12 deletions)
@@ -1,7 +1,7 @@
import os
from datetime import datetime, timezone

-from sqlalchemy import Column, ForeignKey, Integer, String, Text
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import TIMESTAMP
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base
@@ -37,45 +37,45 @@ def read_secret(secret_name: str) -> str:
class UserSettings(Base):
__tablename__ = "user_settings"

-id = Column(Integer, primary_key=True, index=True)
-user_hash = Column(String, unique=True, index=True, nullable=False)
-email = Column(String, nullable=True)
-
-# API Keys
+user_id = Column(String, primary_key=True, unique=True, index=True, nullable=False)
openai_api_key = Column(String, nullable=True)
anthropic_api_key = Column(String, nullable=True)
gemini_api_key = Column(String, nullable=True)
together_api_key = Column(String, nullable=True)

updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now())
created_at = Column(TIMESTAMP(timezone=True), server_default=func.now())

class Dataset(Base):
__tablename__ = "datasets"

id = Column(Integer, primary_key=True, index=True)
+user_id = Column(String, index=True, nullable=False)
name = Column(String, unique=True, nullable=False, index=True)
+shared = Column(Boolean, default=False)
annotation = Column(Text, nullable=False)
created_at = Column(TIMESTAMP(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(TIMESTAMP(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))

class File(Base):
__tablename__ = "files"

id = Column(Integer, primary_key=True, index=True)
+user_id = Column(String, index=True, nullable=False)
file_path = Column(String, unique=True, nullable=False)
+shared = Column(Boolean, default=False)
upload_date = Column(TIMESTAMP(timezone=True), default=lambda: datetime.now(timezone.utc))

class DatasetFile(Base):
__tablename__ = "dataset_files"

dataset_id = Column(Integer, ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, primary_key=True)
file_id = Column(Integer, ForeignKey("files.id", ondelete="CASCADE"), nullable=False, primary_key=True)

class Conversation(Base):
__tablename__ = "conversations"

id = Column(Integer, primary_key=True, index=True)
+user_id = Column(String, index=True, nullable=False)
session_id = Column(String, unique=True, nullable=False, index=True)
title = Column(String, nullable=True) # Auto-generated from first query
dataset_ids = Column(String, nullable=True) # Comma-separated dataset IDs
@@ -84,7 +84,7 @@ class Conversation(Base):

class Message(Base):
__tablename__ = "messages"

id = Column(Integer, primary_key=True, index=True)
conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False)
role = Column(String, nullable=False) # 'user', 'assistant', 'status', 'error', 'result'
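With user_id stored directly on datasets, files, and conversations, per-user visibility becomes a filter in each query rather than a join through a users table. A sketch of that pattern against the models above, assuming they import from app.database; the helper name is illustrative:

# Sketch only: scope reads to the authenticated user plus shared rows.
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import Dataset  # assumed import path

async def list_visible_datasets(session: AsyncSession, user_id: str) -> list[Dataset]:
    # Own rows, plus anything another user flagged as shared.
    stmt = select(Dataset).where(
        or_(Dataset.user_id == user_id, Dataset.shared.is_(True))
    )
    result = await session.execute(stmt)
    return list(result.scalars().all())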
app/backend/app/models/schemas.py (4 changes: 4 additions & 0 deletions)
@@ -19,9 +19,13 @@ class FileItem(BaseModel):
size: int | None = None
modified: datetime | None = None

+class FileBatchDelete(BaseModel):
+    files: list[str]
+
# Dataset schemas
class DatasetCreate(BaseModel):
name: str
+shared: bool
annotation: str
files: list[str]

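FileBatchDelete gives batch deletion a typed request body. A sketch of an endpoint consuming it together with the new auth dependency; the route path and the delete_file helper are hypothetical, not part of this PR:

# Sketch only: a batch-delete endpoint built on the new schema.
from fastapi import APIRouter, Depends

from app.auth import get_current_user           # from this PR
from app.models.schemas import FileBatchDelete  # from this PR

router = APIRouter()

async def delete_file(path: str, user_id: str) -> bool:
    # Hypothetical helper: verify ownership via user_id, then delete the row.
    return True  # placeholder

@router.post("/files/batch-delete")
async def batch_delete_files(
    payload: FileBatchDelete,
    user_id: str = Depends(get_current_user),
) -> dict[str, int]:
    deleted = 0
    for path in payload.files:
        if await delete_file(path, user_id):
            deleted += 1
    return {"deleted": deleted}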