Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compose/debug.docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ services:
OASIS_ADMIN_PASS: password
volumes:
- filestore-OasisData:/shared-fs:rw
- ../src/server/oasisapi:/var/www/oasis/src/server/oasisapi
- ../src:/var/www/oasis/src
server-websocket:
restart: always
image: coreoasis/api_server:dev
Expand All @@ -81,7 +81,7 @@ services:
<<: *shared-env
volumes:
- filestore-OasisData:/shared-fs:rw
- ../src/server/oasisapi:/var/www/oasis/src/server/oasisapi
- ../src:/var/www/oasis/src
v1-worker-monitor:
restart: always
image: coreoasis/api_server:dev
Expand All @@ -94,7 +94,7 @@ services:
<<: *shared-env
volumes:
- filestore-OasisData:/shared-fs:rw
- ../src/server/oasisapi:/var/www/oasis/src/server/oasisapi
- ../src:/var/www/oasis/src
v2-worker-monitor:
restart: always
image: coreoasis/api_server:dev
Expand All @@ -107,7 +107,7 @@ services:
<<: *shared-env
volumes:
- filestore-OasisData:/shared-fs:rw
- ../src/server/oasisapi:/var/www/oasis/src/server/oasisapi
- ../src:/var/www/oasis/src
v2-task-controller:
restart: always
image: coreoasis/api_server:dev
Expand Down
43 changes: 26 additions & 17 deletions src/conf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,29 @@
)
)

#: Celery config - result backend URI
CELERY_RESULTS_DB_BACKEND = settings.get('celery', 'DB_ENGINE', fallback='db+sqlite')
if CELERY_RESULTS_DB_BACKEND == 'db+sqlite':
CELERY_RESULT_BACKEND = '{DB_ENGINE}:///{DB_NAME}'.format(
DB_ENGINE=CELERY_RESULTS_DB_BACKEND,
DB_NAME=settings.get('celery', 'db_name', fallback='celery.db.sqlite'),
)
else:
CELERY_RESULT_BACKEND = '{DB_ENGINE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}{SSL_MODE}'.format(
DB_ENGINE=settings.get('celery', 'db_engine'),
DB_USER=urllib.parse.quote(settings.get('celery', 'db_user')),
DB_PASS=urllib.parse.quote(settings.get('celery', 'db_pass')),
DB_HOST=settings.get('celery', 'db_host'),
DB_PORT=settings.get('celery', 'db_port'),
DB_NAME=settings.get('celery', 'db_name', fallback='celery'),
SSL_MODE=settings.get('celery', 'db_ssl_mode', fallback='?sslmode=prefer'),
)

CELERY_RESULT_BACKEND = "src.conf.custom_celery_db.backends.AuthTokenDatabaseBackend"
CELERY_BACKEND_TOKEN_CLASS = "src.conf.custom_celery_db.token_providers.StaticTokenProvider"
CELERY_BACKEND_TOKEN_CONFIG = {
#'token': urllib.parse.quote(settings.get('celery', 'db_pass')),
'token': 'INVALID',
}


##: Celery config - result backend URI
#CELERY_RESULTS_DB_BACKEND = settings.get('celery', 'DB_ENGINE', fallback='db+sqlite')
#if CELERY_RESULTS_DB_BACKEND == 'db+sqlite':
# CELERY_RESULT_BACKEND = '{DB_ENGINE}:///{DB_NAME}'.format(
# DB_ENGINE=CELERY_RESULTS_DB_BACKEND,
# DB_NAME=settings.get('celery', 'db_name', fallback='celery.db.sqlite'),
# )
#else:
# CELERY_RESULT_BACKEND = '{DB_ENGINE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}{SSL_MODE}'.format(
# DB_ENGINE=settings.get('celery', 'db_engine'),
# DB_USER=urllib.parse.quote(settings.get('celery', 'db_user')),
# DB_PASS=urllib.parse.quote(settings.get('celery', 'db_pass')),
# DB_HOST=settings.get('celery', 'db_host'),
# DB_PORT=settings.get('celery', 'db_port'),
# DB_NAME=settings.get('celery', 'db_name', fallback='celery'),
# SSL_MODE=settings.get('celery', 'db_ssl_mode', fallback='?sslmode=prefer'),
# )
Empty file.
166 changes: 166 additions & 0 deletions src/conf/custom_celery_db/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import threading
from datetime import datetime, timedelta

from sqlalchemy import event
from sqlalchemy.engine.url import make_url

from celery.backends.database import DatabaseBackend
from celery.backends.database.session import SessionManager
from celery.exceptions import BackendError

from celery import current_app

import logging

import importlib
import urllib
from src.conf.iniconf import settings


logger = logging.getLogger(__name__)



def import_class_from_string(class_path: str):
"""Import a class from a string path like 'module.submodule.ClassName'."""
module_path, class_name = class_path.rsplit('.', 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)






class TokenSessionManager(SessionManager):
"""
A custom SessionManager that attaches an event listener to any engine
it creates. This is the reliable way to inject our token logic.
"""
def __init__(self, backend, *args, **kwargs):
self.backend = backend # Keep a reference to our backend instance
super().__init__(*args, **kwargs)
self._init_token_provider()



def _init_token_provider(self):
"""Initialize the token provider from configuration."""
# Get token provider class from config
token_class_path = current_app.conf.get('CELERY_BACKEND_TOKEN_CLASS')
if not token_class_path:
raise ValueError("CELERY_BACKEND_TOKEN_CLASS must be configured")

# Get token provider config
token_config = current_app.conf.get('CELERY_BACKEND_TOKEN_CONFIG', {})

# Import and instantiate the token provider
TokenProviderClass = import_class_from_string(token_class_path)
self.token_provider = TokenProviderClass(**token_config)
print(f"Initialized token provider: {TokenProviderClass.__name__}")




def get_engine(self, dburi, **kwargs):
# Always get fresh token before creating engine
#self.backend._ensure_token_valid()
#
## Update the URI with fresh token
from celery.contrib import rdb; rdb.set_trace()
token_url = make_url(dburi)
token_url = token_url.set(password=self.token_provider.get_token())
engine = super().get_engine(token_url, **kwargs)




## this should be called when establishing a connection
#@event.listens_for(engine, "connect")
#def _on_connect(dbapi_connection, connection_record):
# logger.info("New connection established")

## this is called when connection failed
#@event.listens_for(engine, "invalidate")
#def _on_invalidate(dbapi_connection, connection_record, exception):
# logger.info(f"Connection invalidated: {exception}")
# # Next connection attempt will call get_engine again


@event.listens_for(engine, "handle_error")
def handle_token_error(exception_context):
from celery.contrib import rdb; rdb.set_trace()
error_msg = str(exception_context.original_exception).lower()
logger.info(f"Error with DB connection: {error_msg}")

## Check if it's a token/auth related error
if any(pattern in error_msg for pattern in [
'authentication failed', 'access denied', 'token expired',
'invalid authorization', 'password authentication failed'
]):
logger.info("Token authentication failed, refreshing...")
self.token_provider.force_refresh()

# Invalidate current connections to force new ones
engine.dispose()

# Optionally return True to suppress the original error
# and let SQLAlchemy retry with the new token
return True



#
# # Refresh the token
# self.backend._ensure_token_valid()
#
# # Update connection info for future connections
# from sqlalchemy.engine.url import make_url
# url = make_url(dburi)
# url = url.set(password=self.backend._token)
#
# engine.dispose()
#


return engine





class AuthTokenDatabaseBackend(DatabaseBackend):
"""
A custom Celery result backend that subclasses the standard DatabaseBackend
to handle expiring database tokens (e.g., IAM tokens for PostgreSQL/MySQL).
"""
def __init__(self, dburi=None, *args, **kwargs):
logger.info("RUNNING CELERY CUSTOM RESULTS DB BACKEND")
super().__init__(dburi=self._get_database_connection_string(), *args, **kwargs)


def _get_database_connection_string(self):

# Load DB connection string
dburi = '{DB_ENGINE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}{SSL_MODE}'.format(
#DB_ENGINE=settings.get('celery', 'db_engine'),
DB_ENGINE='postgresql+psycopg', # <---- must strip out 'DB+' if set
DB_USER=urllib.parse.quote(settings.get('celery', 'db_user')),
#DB_PASS=urllib.parse.quote(settings.get('celery', 'db_pass'), '%PLACEHOLDER%'),
DB_PASS="%PLACEHOLDER%",
DB_HOST=settings.get('celery', 'db_host'),
DB_PORT=settings.get('celery', 'db_port'),
DB_NAME=settings.get('celery', 'db_name', fallback='celery'),
SSL_MODE=settings.get('celery', 'db_ssl_mode', fallback='?sslmode=prefer'),
)
return dburi


# Load custom session Manager
def ResultSession(self):
session_manager=TokenSessionManager(backend=self)
return session_manager.session_factory(
dburi=self.url,
short_lived_sessions=self.short_lived_sessions,
**self.engine_options)

42 changes: 42 additions & 0 deletions src/conf/custom_celery_db/token_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, Tuple



class TokenProvider(ABC):
"""Abstract base class for token providers."""

@abstractmethod
def get_token(self) -> str:
"""Get a valid token, refreshing if necessary."""
raise NotImplementedError("Subclasses must implement get_token()")

@abstractmethod
def is_token_expired(self) -> bool:
"""Check if the current token is expired."""
raise NotImplementedError("Subclasses must implement is_token_expired()")

@abstractmethod
def force_refresh(self) -> str:
"""Force a token refresh and return the new token."""
raise NotImplementedError("Subclasses must implement force_refresh()")



class StaticTokenProvider(TokenProvider):
"""Simple token provider that returns a static token (for testing)."""

def __init__(self, token: str):
self.token = token

def get_token(self) -> str:
return self.token

def is_token_expired(self) -> bool:
return False

def force_refresh(self) -> str:
return self.token


Loading