# candle-annotator/services/ml/app/db.py

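"""
Database connection and session management for the ML service.

Provides the SQLAlchemy engine, session factory, and declarative base for
PostgreSQL, the ``TrainingRun`` ORM model, and helpers to create the schema
(``init_db``) and to obtain scoped sessions (``get_db``). The connection
string is read from the required ``DATABASE_URL`` environment variable;
connection-pool settings are fixed in this module.
"""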
import os
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, Column, Integer, String, DateTime, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import func
# The connection string is mandatory. Refuse to import this module when it is
# unset or empty, so misconfiguration fails immediately rather than on first query.
DATABASE_URL = os.environ.get("DATABASE_URL")
if not DATABASE_URL:
    raise RuntimeError(
        "DATABASE_URL environment variable is required. "
        "Please set it to a valid PostgreSQL connection string."
    )
# One engine shared by the process. The pool keeps up to 5 persistent
# connections and allows 10 more under burst load. pool_pre_ping tests each
# connection on checkout, so a stale one is replaced instead of failing the
# caller's query.
engine = create_engine(
    DATABASE_URL,
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,
)

# Factory for Session objects bound to ``engine``. Callers must commit
# explicitly (autocommit=False). Pending changes are not flushed implicitly
# before queries (autoflush=False).
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base. ORM models subclass it so their tables register on
# ``Base.metadata``.
# NOTE(review): ``sqlalchemy.ext.declarative.declarative_base`` is deprecated
# since SQLAlchemy 1.4 in favour of ``sqlalchemy.orm.declarative_base``;
# consider switching the import once the pinned version is confirmed.
Base = declarative_base()
# Training runs model
class TrainingRun(Base):
    """ORM model for the ``training_runs`` table: one row per ML training run.

    Attribute declaration order determines column order in the table that
    ``Base.metadata.create_all`` emits, so avoid reordering the columns.
    """
    __tablename__ = "training_runs"

    # Surrogate integer primary key.
    # NOTE(review): PostgreSQL already indexes a primary key; index=True likely
    # creates a second, redundant index. Confirm before removing it, because
    # that changes the DDL emitted for new databases.
    id = Column(Integer, primary_key=True, index=True)
    # Caller-supplied run identifier; unique, required, indexed for lookups.
    run_id = Column(String(255), unique=True, nullable=False, index=True)
    # Required label for the kind of model trained (max 100 chars).
    model_type = Column(String(100), nullable=False)
    # Required experiment grouping; indexed for per-experiment queries.
    experiment_name = Column(String(255), nullable=False, index=True)
    # Required; 64 chars fits a hex SHA-256 digest. Presumably that is what
    # is stored, but confirm with the writer.
    pipeline_config_hash = Column(String(64), nullable=False)
    # Optional dataset version label.
    dataset_version = Column(String(100))
    # Optional free-form JSON. Its key schema is not defined here.
    # TODO(review): document the expected keys.
    metrics_summary = Column(JSON)
    # Lifecycle state. default="running" is a client-side ORM default applied
    # on INSERT through the ORM, not a database DEFAULT, so raw SQL inserts
    # must supply a value.
    status = Column(String(50), nullable=False, default="running", index=True)
    # Filled in by the database at insert time (DEFAULT now()); indexed.
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
    # Nullable with no default; stays NULL until a writer sets it.
    completed_at = Column(DateTime(timezone=True))

    def __repr__(self) -> str:
        """Return a short debug string showing ``run_id`` and ``status``."""
        return f"<TrainingRun(run_id='{self.run_id}', status='{self.status}')>"
def init_db():
    """Create the tables for every model registered on ``Base.metadata``.

    Issues CREATE TABLE through ``engine`` for each declared model
    (e.g. ``TrainingRun``). With SQLAlchemy's default ``checkfirst=True``,
    tables that already exist are skipped, and existing tables are never
    altered, so this is not a migration tool.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """Yield a new ``SessionLocal`` session scoped to a ``with`` block.

    The session is closed when the block exits, whether it exits normally or
    through an exception. Nothing is committed automatically: callers must
    call ``commit()`` themselves. Work left uncommitted when the session
    closes is not persisted.

    Example::

        with get_db() as db:
            db.add(TrainingRun(run_id="123", ...))
            db.commit()

    Yields:
        A fresh database session.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session's connection resources, even on error.
        session.close()