candle-annotator/services/ml/app/db.py
Marko Djordjevic ea339a54a7 feat(ml): add database schema, config parser, and DVC setup
- Initialize DVC with local storage backend (task 1.6)
- Create PostgreSQL schema for training_runs table (task 1.7)
- Add SQLAlchemy database connection setup (task 1.8)
- Create Pydantic config models for pipeline.yaml (task 2.1)
- Add migration runner for database setup
- Fix pyproject.toml package discovery config
2026-02-15 12:08:53 +01:00

106 lines
3 KiB
Python

"""
Database connection and session management for the ML service.
This module provides SQLAlchemy engine and session setup for PostgreSQL.
Environment variables control the connection parameters.
"""
import os
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, Column, Integer, String, DateTime, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import func
def build_database_url(env=None) -> str:
    """
    Build a PostgreSQL connection URL from ``POSTGRES_*`` settings.

    Args:
        env: Mapping to read settings from; defaults to ``os.environ``.
            Recognized keys and their fallbacks are POSTGRES_USER
            ("ml_user"), POSTGRES_PASSWORD ("ml_password"), POSTGRES_HOST
            ("localhost"), POSTGRES_PORT ("5432") and POSTGRES_DB
            ("ml_service").

    Returns:
        A ``postgresql://user:password@host:port/db`` URL string. The user
        name and password are percent-encoded so that characters such as
        ``@``, ``:``, ``/`` or spaces in credentials cannot corrupt the URL.
    """
    from urllib.parse import quote

    if env is None:
        env = os.environ

    def _escaped(key, default):
        # safe="" so "/" is encoded too. quote (not quote_plus) so a space
        # becomes %20 rather than "+", which a plain percent-decoder would not
        # turn back into a space.
        return quote(env.get(key, default), safe="")

    user = _escaped("POSTGRES_USER", "ml_user")
    password = _escaped("POSTGRES_PASSWORD", "ml_password")
    host = env.get("POSTGRES_HOST", "localhost")
    port = env.get("POSTGRES_PORT", "5432")
    database = env.get("POSTGRES_DB", "ml_service")
    return f"postgresql://{user}:{password}@{host}:{port}/{database}"


# Database connection configuration from environment. An explicit DATABASE_URL
# takes precedence and is used verbatim (it must already be URL-encoded);
# otherwise the URL is assembled from the individual POSTGRES_* variables.
DATABASE_URL = os.getenv("DATABASE_URL", build_database_url())
# One process-wide engine backs every session created from SessionLocal.
engine = create_engine(
    DATABASE_URL,
    # Keep 5 connections in the pool and allow up to 10 extra under load.
    pool_size=5,
    max_overflow=10,
    # Test each connection on checkout so stale ones (e.g. after a server
    # restart) are replaced rather than failing the caller's query.
    pool_pre_ping=True,
)

# Session factory bound to the engine above: no automatic flush before
# queries, and transactions are committed only when the caller commits.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base that every ORM model in this module inherits from.
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4; consider importing declarative_base from sqlalchemy.orm.
Base = declarative_base()
# ORM mapping for the ``training_runs`` table.
class TrainingRun(Base):
    """A single ML training run, stored as one row of ``training_runs``."""

    __tablename__ = "training_runs"

    # Integer surrogate primary key.
    # NOTE(review): a primary key is already backed by an index, so index=True
    # makes create_all emit an additional, redundant index. Confirm against the
    # SQL migration before dropping it.
    id = Column(Integer, index=True, primary_key=True)
    # Required run identifier; unique and indexed.
    run_id = Column(String(255), index=True, unique=True, nullable=False)
    # Required model type label (up to 100 characters).
    model_type = Column(String(100), nullable=False)
    # Required experiment name; indexed for per-experiment lookups.
    experiment_name = Column(String(255), index=True, nullable=False)
    # Required 64-character config hash (presumably a hex SHA-256 digest —
    # confirm with the code that computes it).
    pipeline_config_hash = Column(String(64), nullable=False)
    # Optional dataset version tag.
    dataset_version = Column(String(100))
    # Optional JSON document of summary metrics.
    metrics_summary = Column(JSON)
    # Required run state, indexed. The "running" default is applied by
    # SQLAlchemy on ORM inserts, not by the database server.
    status = Column(String(50), index=True, default="running", nullable=False)
    # Timezone-aware insert timestamp filled in by the database via now().
    created_at = Column(DateTime(timezone=True), index=True, server_default=func.now())
    # Timezone-aware completion timestamp; NULL until explicitly set.
    completed_at = Column(DateTime(timezone=True))

    def __repr__(self):
        return "<TrainingRun(run_id='{}', status='{}')>".format(self.run_id, self.status)
def init_db():
    """
    Create the tables for every model registered on ``Base``.

    Runs ``MetaData.create_all`` against the module-level engine. With the
    default ``checkfirst`` behaviour, tables that already exist are skipped;
    existing tables are not altered to match model changes.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """
    Provide a session scoped to a ``with`` block.

    A fresh session is created from ``SessionLocal`` on entry and is closed
    on exit, whether the block completes normally or raises. Committing is
    left to the caller.

    Example:
        with get_db() as db:
            db.add(TrainingRun(run_id="123", ...))
            db.commit()

    Yields:
        The session, valid for the duration of the ``with`` block.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Release the session's connection resources back to the engine.
        session.close()
def get_db_session() -> Session:
    """
    Create and return a new session that the caller owns.

    The caller is responsible for closing the returned session, e.g.::

        db = get_db_session()
        try:
            ...
        finally:
            db.close()

    Do NOT pass this function directly to FastAPI's ``Depends``. FastAPI only
    runs cleanup for dependencies that ``yield``, so the session would never be
    closed. Its pooled connection would then be reclaimed only by the garbage
    collector, if at all, which can exhaust the pool (``pool_size`` +
    ``max_overflow``) under load. Use ``get_db_dependency`` instead::

        @app.get("/")
        def endpoint(db: Session = Depends(get_db_dependency)):
            ...

    Returns:
        A new session from ``SessionLocal`` (caller must close it).
    """
    return SessionLocal()


def get_db_dependency() -> Generator[Session, None, None]:
    """
    Yield a session for use as a FastAPI ``Depends`` dependency.

    FastAPI resumes this generator after handling the request, which runs the
    ``finally`` clause and closes the session even if the endpoint raised.

    Yields:
        A new session from ``SessionLocal``.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()