candle-annotator/services/ml/app/db.py
Marko Djordjevic 688e75e6be Scope training run queries in FastAPI to filter by user ID (Task 14.3)
- Add user_id column to TrainingRun model in db.py
- Store user_id on TrainingRun insert in /training/start
- Filter GET /training/runs by user_id (returns empty list if no user context)
- Enforce user ownership on GET /training/runs/{run_id} (404 on mismatch)
- Enforce user ownership on DELETE /training/runs/{run_id} (404 on mismatch)
- Add migration 002 to add user_id column and index to training_runs table

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-20 18:38:18 +01:00

90 lines
2.6 KiB
Python

"""
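Database connection, session management, and ORM models for the ML service.
This module sets up the SQLAlchemy engine and session factory for PostgreSQL
and defines the TrainingRun model. The connection string is read from the
DATABASE_URL environment variable, which must be set before this module is
imported (a RuntimeError is raised otherwise).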
"""
import os
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, Column, Integer, String, DateTime, JSON
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
from sqlalchemy.sql import func
# PostgreSQL connection string, taken from the process environment. There is
# deliberately no fallback: an unset or empty value aborts the import.
DATABASE_URL = os.environ.get("DATABASE_URL")
if not DATABASE_URL:
    raise RuntimeError(
        "DATABASE_URL environment variable is required. "
        "Please set it to a valid PostgreSQL connection string."
    )
# Process-wide SQLAlchemy engine. The pool keeps up to 5 persistent
# connections, may open up to 10 more under load, and pings each pooled
# connection before handing it out so stale ones are replaced transparently.
engine = create_engine(
    DATABASE_URL,
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,
)

# Session factory bound to the engine above; autoflush and autocommit are off,
# so callers flush/commit explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
class Base(DeclarativeBase):
    """Declarative base class for the ORM models defined in this module."""
# Training runs model
class TrainingRun(Base):
    """
    Model for tracking ML training runs (table ``training_runs``).

    Each row describes one training run, identified externally by ``run_id``
    and owned by the user in ``user_id``.
    """
    __tablename__ = "training_runs"
    # Surrogate integer primary key.
    # NOTE(review): the primary key is already indexed by the database, so
    # index=True likely creates a second, redundant index; removing it changes
    # the schema, so coordinate with migrations before touching it.
    id = Column(Integer, primary_key=True, index=True)
    # External identifier of the run: required, unique, and indexed for lookups.
    run_id = Column(String(255), unique=True, nullable=False, index=True)
    # Owning user. Indexed so per-user listings stay cheap. Nullable —
    # presumably so rows written without a user context (or before this column
    # was added) remain valid; confirm against the migration.
    user_id = Column(String(255), nullable=True, index=True)
    # Kind of model trained; the set of allowed values is not defined here.
    model_type = Column(String(100), nullable=False)
    # Experiment the run belongs to; indexed for filtering by experiment.
    experiment_name = Column(String(255), nullable=False, index=True)
    # Hash of the pipeline configuration. 64 chars fits a hex SHA-256 digest —
    # presumably that is what is stored; confirm with the writer.
    pipeline_config_hash = Column(String(64), nullable=False)
    # Optional dataset version label.
    dataset_version = Column(String(100))
    # Free-form JSON summary of metrics; its structure is set by the writer,
    # not by this model.
    metrics_summary = Column(JSON)
    # Lifecycle state. SQLAlchemy supplies "running" on INSERT when no value is
    # given (client-side default; there is no database server default).
    status = Column(String(50), nullable=False, default="running", index=True)
    # Creation timestamp, filled in by the database via NOW() on INSERT.
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
    # Completion timestamp; NULL until set — presumably when the run finishes.
    completed_at = Column(DateTime(timezone=True))
    def __repr__(self) -> str:
        """Return a short debug representation showing run_id and status."""
        return f"<TrainingRun(run_id='{self.run_id}', status='{self.status}')>"
def init_db():
    """
    Create any missing tables for the models registered on ``Base``.

    Runs ``Base.metadata.create_all`` against the module-level engine. Tables
    that already exist are skipped rather than altered, so column changes to
    existing tables (such as adding ``training_runs.user_id``) must be applied
    through migrations instead.
    """
    metadata = Base.metadata
    metadata.create_all(engine, checkfirst=True)
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """
    Open a database session that lives for the duration of a ``with`` block.

    A new session is created from ``SessionLocal`` on entry and closed on exit,
    whether the block completes normally or raises. Committing is left to the
    caller.

    Usage:
        with get_db() as db:
            # Use db session here
            training_run = TrainingRun(run_id="123", ...)
            db.add(training_run)
            db.commit()

    Yields:
        Database session
    """
    # Session's own context manager calls close() on exit, replacing the
    # explicit try/finally.
    with SessionLocal() as session:
        yield session