candle-annotator/services/ml/app/db.py
Marko Djordjevic aa81d4f3d0 fix(ml): complete ML pipeline fixes and setup
- Fix CCI indicator to use HLC prices instead of close only
- Parse datetime column when loading enriched CSV
- Strip timezone from annotation timestamps
- Fix TA-Lib pattern names (CDL3WHITESOLDIERS, CDL3BLACKCROWS)
- Exclude programmatic label columns from training features
- Fix classification report to handle missing classes
- Update MLflow tracking to use localhost:5000
- Grant PostgreSQL permissions to ml_user

Pipeline now runs successfully end-to-end:
- Feature engineering: 2543 rows, 31 columns
- Annotation ingestion: 286 samples
- Training: 89.47% test accuracy with Random Forest
2026-02-15 21:29:54 +01:00

111 lines
3.1 KiB
Python

"""
Database connection and session management for the ML service.
This module provides SQLAlchemy engine and session setup for PostgreSQL.
Environment variables control the connection parameters.
"""
import os
from contextlib import contextmanager
from typing import Generator, Mapping, Optional
from urllib.parse import quote

from sqlalchemy import create_engine, Column, Integer, String, DateTime, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import func
# One-time PostgreSQL setup matching the default connection settings below:
#   CREATE DATABASE ml_service;
#   CREATE USER ml_user WITH ENCRYPTED PASSWORD 'ml_password';
#   GRANT ALL PRIVILEGES ON DATABASE ml_service TO ml_user;
# Database connection configuration from environment
def build_database_url(env: Optional[Mapping[str, str]] = None) -> str:
    """
    Assemble the PostgreSQL URL from the individual POSTGRES_* settings.

    The user name and password are percent-encoded so that characters that
    carry meaning inside a URL (for example "@", ":", "/" or "#") cannot
    corrupt the resulting connection string.

    Args:
        env: Mapping to read the settings from. Defaults to ``os.environ``.

    Returns:
        A URL of the form ``postgresql://user:password@host:port/database``.
        Settings that are absent fall back to ml_user, ml_password,
        localhost, 5432 and ml_service respectively.
    """
    settings = os.environ if env is None else env

    # safe="" makes quote() encode "/" as well, which it leaves alone by default.
    user = quote(settings.get("POSTGRES_USER", "ml_user"), safe="")
    password = quote(settings.get("POSTGRES_PASSWORD", "ml_password"), safe="")
    host = settings.get("POSTGRES_HOST", "localhost")
    port = settings.get("POSTGRES_PORT", "5432")
    database = settings.get("POSTGRES_DB", "ml_service")

    return f"postgresql://{user}:{password}@{host}:{port}/{database}"


# An explicit DATABASE_URL is used verbatim and takes precedence over the
# individual POSTGRES_* variables.
DATABASE_URL = os.getenv("DATABASE_URL", build_database_url())
# Module-level SQLAlchemy engine built from DATABASE_URL
engine = create_engine(
    DATABASE_URL,
    pool_size=5,  # Connections kept in the pool
    max_overflow=10,  # Extra connections allowed beyond pool_size
    pool_pre_ping=True,  # Check a connection is alive before using it
)

# Factory for sessions bound to the engine; autoflush and autocommit are
# both switched off.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that the ORM models inherit from
Base = declarative_base()
# ORM model backing the "training_runs" table
class TrainingRun(Base):
    """Record describing a single ML training run."""

    __tablename__ = "training_runs"

    # Integer primary key.
    id = Column(Integer, index=True, primary_key=True)
    # Required, unique identifier of the run.
    run_id = Column(String(255), index=True, nullable=False, unique=True)
    # Required model type, up to 100 characters.
    model_type = Column(String(100), nullable=False)
    # Required, indexed experiment name.
    experiment_name = Column(String(255), index=True, nullable=False)
    # Required; presumably a hex digest of the pipeline configuration
    # (64 characters would fit SHA-256) -- TODO confirm with the writer.
    pipeline_config_hash = Column(String(64), nullable=False)
    # Optional dataset version label.
    dataset_version = Column(String(100))
    # Optional JSON document with summary metrics.
    metrics_summary = Column(JSON)
    # Required, indexed status; set to "running" when no value is supplied.
    status = Column(String(50), default="running", index=True, nullable=False)
    # Timezone-aware creation time, filled in by the database via now().
    created_at = Column(DateTime(timezone=True), index=True, server_default=func.now())
    # Timezone-aware completion time; stays empty until it is set.
    completed_at = Column(DateTime(timezone=True))

    def __repr__(self):
        """Return a short debug representation with run_id and status."""
        return f"<TrainingRun(run_id='{self.run_id}', status='{self.status}')>"
def init_db():
    """
    Create the database schema.

    Asks Base.metadata to create every table registered on it, using the
    module-level engine.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """
    Provide a database session that is closed automatically.

    Nothing is committed on the caller's behalf; the session is closed when
    the ``with`` block exits, whether normally or through an exception.

    Usage:
        with get_db() as db:
            # Use db session here
            training_run = TrainingRun(run_id="123", ...)
            db.add(training_run)
            db.commit()

    Yields:
        Database session
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal exit and on exceptions alike.
        session.close()
def get_db_session() -> Session:
    """
    Create a new database session and hand ownership to the caller.

    This function never closes the session it returns, so the caller is
    responsible for calling close() on it. Use get_db() instead when the
    session should be closed automatically.

    Usage with FastAPI:
        @app.get("/")
        def endpoint(db: Session = Depends(get_db_session)):
            try:
                ...  # Use db here
            finally:
                db.close()

    Returns:
        Database session (caller must close it)
    """
    new_session = SessionLocal()
    return new_session