""" Database connection and session management for the ML service. This module provides SQLAlchemy engine and session setup for PostgreSQL. Environment variables control the connection parameters. """ import os from contextlib import contextmanager from typing import Generator from sqlalchemy import create_engine, Column, Integer, String, DateTime, JSON from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.sql import func # CREATE DATABASE ml_service; # CREATE USER ml_user WITH ENCRYPTED PASSWORD 'ml_password'; # GRANT ALL PRIVILEGES ON DATABASE ml_service TO ml_user; # Database connection configuration from environment DATABASE_URL = os.getenv( "DATABASE_URL", f"postgresql://{os.getenv('POSTGRES_USER', 'ml_user')}:" f"{os.getenv('POSTGRES_PASSWORD', 'ml_password')}@" f"{os.getenv('POSTGRES_HOST', 'localhost')}:" f"{os.getenv('POSTGRES_PORT', '5432')}/" f"{os.getenv('POSTGRES_DB', 'ml_service')}" ) # Create SQLAlchemy engine engine = create_engine( DATABASE_URL, pool_pre_ping=True, # Verify connections before using them pool_size=5, # Number of connections to maintain max_overflow=10, # Max connections beyond pool_size ) # Create session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # Base class for declarative models Base = declarative_base() # Training runs model class TrainingRun(Base): """Model for tracking ML training runs.""" __tablename__ = "training_runs" id = Column(Integer, primary_key=True, index=True) run_id = Column(String(255), unique=True, nullable=False, index=True) model_type = Column(String(100), nullable=False) experiment_name = Column(String(255), nullable=False, index=True) pipeline_config_hash = Column(String(64), nullable=False) dataset_version = Column(String(100)) metrics_summary = Column(JSON) status = Column(String(50), nullable=False, default="running", index=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True) completed_at = Column(DateTime(timezone=True)) def __repr__(self): return f"" def init_db(): """ Initialize the database schema. Creates all tables defined by Base.metadata. """ Base.metadata.create_all(bind=engine) @contextmanager def get_db() -> Generator[Session, None, None]: """ Context manager for database sessions. Usage: with get_db() as db: # Use db session here training_run = TrainingRun(run_id="123", ...) db.add(training_run) db.commit() Yields: Database session """ db = SessionLocal() try: yield db finally: db.close() def get_db_session() -> Session: """ Get a database session (for dependency injection). Usage with FastAPI: @app.get("/") def endpoint(db: Session = Depends(get_db_session)): # Use db here Returns: Database session (caller must close it) """ return SessionLocal()