- Add python-dotenv loading in main.py so DATABASE_URL is read from .env before db.py module initializes - Add MLFLOW_TRACKING_URI to .env.example pointing to PostgreSQL - Add python-dotenv>=1.0.0 to pyproject.toml dependencies - Initialize MLflow schema in candle_annotator PostgreSQL database MLflow server now starts without filesystem deprecation warnings and with full job execution support. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1897 lines
66 KiB
Python
1897 lines
66 KiB
Python
"""
|
||
FastAPI inference service for candlestick pattern prediction.
|
||
|
||
Provides REST API endpoints for model serving, health checks, and prediction.
|
||
"""
|
||
|
||
import concurrent.futures
|
||
import hashlib
|
||
import logging
|
||
import os
|
||
import re
|
||
import threading
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from typing import Optional, Dict, Any, List
|
||
from datetime import datetime, timezone
|
||
import json
|
||
|
||
# Load .env file before any module that reads environment variables
|
||
from dotenv import load_dotenv
|
||
load_dotenv(Path(__file__).parent.parent / ".env")
|
||
|
||
import requests as http_requests
|
||
|
||
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel, Field
|
||
import numpy as np
|
||
import pandas as pd
|
||
import joblib
|
||
import mlflow
|
||
import mlflow.sklearn
|
||
import mlflow.xgboost
|
||
from sqlalchemy import update as sa_update, desc, text as sa_text
|
||
|
||
from app.config import load_config, PipelineConfig, get_default_config
|
||
from app.db import get_db, TrainingRun, init_db
|
||
from app.preprocessing import preprocess_candles, extract_feature_columns
|
||
from app.patterns import (
|
||
TALIB_PATTERNS,
|
||
get_available_patterns,
|
||
validate_pattern_names,
|
||
detect_patterns,
|
||
)
|
||
|
||
# Configure logging
|
||
# Root logging configuration: INFO level, "<time> - <logger> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by every helper and endpoint in this file.
logger = logging.getLogger(__name__)
|
||
|
||
# --- API Key Dependency ---
|
||
|
||
async def verify_api_key(x_api_key: str = Header(default="")):
    """
    Reject requests whose ``X-API-Key`` header does not match ``API_KEY``.

    The expected key is read from the ``API_KEY`` environment variable on every
    call. When it is unset or empty the check is skipped (fail-open), so
    unconfigured deployments accept all requests.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header ("" when absent).

    Raises:
        HTTPException: 401 when a key is configured and the header differs.
    """
    import hmac  # only needed for the constant-time comparison below

    api_key = os.getenv("API_KEY")
    if not api_key:
        return  # fail-open if not configured

    # Constant-time comparison: a plain `!=` can short-circuit on the first
    # differing character and leak key prefixes through response timing.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(x_api_key.encode("utf-8"), api_key.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||
|
||
|
||
# --- User ID Header Dependency ---
|
||
|
||
async def get_user_id(x_user_id: str = Header(default="")) -> Optional[str]:
    """
    Return the caller-supplied ``X-User-ID`` header, or None when it is empty.

    The header is optional: this dependency never raises, and each endpoint
    decides for itself whether a missing user ID is acceptable.
    """
    if not x_user_id:
        return None
    return x_user_id
|
||
|
||
|
||
# --- Lifespan ---
|
||
|
||
def _mark_stale_training_runs_failed() -> None:
    """
    Mark every TrainingRun still in status "running" as failed.

    Such rows belong to a previous process and can never complete. Errors are
    logged and swallowed so a cleanup problem does not block startup.
    """
    try:
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.now(timezone.utc),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            result = db.execute(stmt)
            db.commit()
            if result.rowcount:
                logger.warning(
                    f"Marked {result.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as exc:
        logger.error(f"Failed to clean up stale training runs: {exc}")


def _load_inference_model(inference_config: Any) -> None:
    """
    Load the model selected by ``inference_config`` into the global ``state``.

    Sets ``state.model``, ``state.model_info`` and ``state.feature_columns``.

    Raises:
        ValueError: If ``model_source`` is neither "mlflow" nor "local".
        Exception: Anything raised by the underlying loader.
    """
    source = inference_config.model_source
    if source == "mlflow":
        # Point the MLflow client at the tracking server configured for training.
        mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
        model, model_info = load_model_from_mlflow(
            model_name=inference_config.mlflow_model_name,
            stage=inference_config.mlflow_model_stage,
        )
    elif source == "local":
        model, model_info = load_model_from_local(
            model_path=inference_config.local_model_path
        )
    else:
        # Raise instead of merely logging: falling through used to crash on
        # `state.model_info.get` (None) with a misleading AttributeError.
        raise ValueError(f"Unknown model_source: {source}")

    state.model, state.model_info = model, model_info
    state.feature_columns = model_info.get("feature_columns")
    logger.info("Model loaded successfully")
    logger.info(
        f"Model info: {model_info['model_name']} "
        f"v{model_info['model_version']} "
        f"({model_info['model_type']})"
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager.

    Startup performs these steps:
    - creates the DB tables via ``init_db()``;
    - fails stale "running" training rows;
    - loads ``config/pipeline.yaml`` into ``state.pipeline_config``;
    - when the inference stage is enabled, loads the configured model.

    Model-load failures are logged and the service still starts without a
    model. Nothing is done on shutdown.
    """
    # --- Startup ---
    logger.info("Starting inference service...")

    # Ensure training_runs table exists (errors here abort startup).
    init_db()

    _mark_stale_training_runs_failed()

    config_path = Path("config/pipeline.yaml")
    if config_path.exists():
        try:
            state.pipeline_config = load_config(config_path)
            logger.info(f"Loaded pipeline config from {config_path}")

            inference_config = state.pipeline_config.stages.inference
            if not inference_config.enabled:
                logger.warning("Inference stage is disabled in config")
            else:
                try:
                    _load_inference_model(inference_config)
                except Exception as e:
                    logger.error(f"Failed to load model: {e}")
                    logger.warning("Service will start without a model. Use /health to check status.")
        except Exception as e:
            logger.error(f"Failed to load pipeline config: {e}")
    else:
        # NOTE(review): despite the message, no default config is applied here;
        # state.pipeline_config stays None, so /predict answers 500
        # "Pipeline configuration not loaded". get_default_config is imported
        # by this module — confirm whether it was meant to be used here.
        logger.warning(f"Config file not found: {config_path}. Using defaults.")

    yield

    # --- Shutdown (nothing to do currently) ---
|
||
|
||
|
||
# FastAPI app
|
||
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0",
    lifespan=lifespan,
)

# Allowed CORS origins come from the comma-separated CORS_ORIGINS env var
# (default "http://localhost:3000"); whitespace around each entry is stripped.
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
allow_origins = list(map(str.strip, cors_origins_str.split(",")))

# Credentials are allowed, with every method and header, for the origins above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# Global state
|
||
class AppState:
    """
    Mutable, process-wide container for the service's runtime state.

    A single instance, ``state``, is created at import time and shared by the
    lifespan hook and the request handlers.
    """

    # Loaded estimator (MLflow pyfunc model or joblib-loaded object), or None.
    model: Optional[Any] = None
    # Metadata dict built by load_model_from_mlflow / load_model_from_local.
    model_info: Optional[Dict[str, Any]] = None
    # Parsed pipeline config, or None when it could not be loaded.
    pipeline_config: Optional[PipelineConfig] = None
    # Feature column names recorded in the model metadata, when present.
    feature_columns: Optional[List[str]] = None
    # Optional int -> label-name mapping applied to raw integer predictions.
    label_encoder: Optional[Dict[int, str]] = None

    # Training thread management
    active_training_run_id: Optional[str] = None
    # Class-level default is None; every instance gets its own lock in __init__.
    training_lock: Optional[threading.Lock] = None

    def __init__(self) -> None:
        self.training_lock = threading.Lock()


state = AppState()
|
||
|
||
def _get_model_n_features(model: Any) -> Optional[int]:
|
||
"""
|
||
Extract expected feature count from a model (supports wrappers).
|
||
"""
|
||
if hasattr(model, "n_features_in_"):
|
||
try:
|
||
return int(model.n_features_in_)
|
||
except Exception:
|
||
return None
|
||
if hasattr(model, "model") and hasattr(model.model, "n_features_in_"):
|
||
try:
|
||
return int(model.model.n_features_in_)
|
||
except Exception:
|
||
return None
|
||
return None
|
||
|
||
|
||
# --- Pydantic Models ---
|
||
|
||
class CandleData(BaseModel):
    """One OHLC candle with an optional volume."""

    time: int = Field(description="Unix timestamp in seconds")
    open: float
    high: float
    low: float
    close: float
    volume: Optional[float] = Field(default=None)
|
||
|
||
|
||
class PredictRequest(BaseModel):
    """Body of POST /predict: at least one candle plus optional pair/timeframe tags."""

    # pair and timeframe are informational only; /predict uses them in its log line.
    pair: Optional[str] = Field(default=None, description="Trading pair (e.g., EURUSD)")
    timeframe: Optional[str] = Field(default=None, description="Timeframe (e.g., 1H, 4H, 1D)")
    candles: List[CandleData] = Field(min_length=1, description="Array of candle data")
|
||
|
||
|
||
class PredictionResult(BaseModel):
    """Model output for one window, stamped with the window's last candle time."""

    time: int
    label: str
    confidence: float = Field(ge=0.0, le=1.0)
|
||
|
||
|
||
class PredictionSpan(BaseModel):
    """Run of consecutive predictions sharing one (non-background) label."""

    start_time: int
    end_time: int
    label: str
    # Mean confidence of the merged predictions.
    avg_confidence: float = Field(ge=0.0, le=1.0)
    # Number of predictions merged into the span.
    candle_count: int
|
||
|
||
|
||
class ModelInfo(BaseModel):
    """Identity and provenance of the model that produced a prediction."""

    model_name: str
    model_version: Optional[str] = Field(default=None)
    model_type: str
    trained_at: Optional[str] = Field(default=None)
    dataset_version: Optional[str] = Field(default=None)
    feature_engineering_enabled: bool
|
||
|
||
|
||
class PredictResponse(BaseModel):
    """Result of /predict and /predict/batch: per-window predictions, spans, model info."""

    predictions: List[PredictionResult]
    spans: List[PredictionSpan]
    model_info: ModelInfo
|
||
|
||
|
||
class BatchPredictRequest(BaseModel):
    """Body of POST /predict/batch: a pair/timeframe and a YYYY-MM-DD date range."""

    pair: str
    timeframe: str
    start_date: str = Field(description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(description="ISO format date (YYYY-MM-DD)")
|
||
|
||
|
||
class LabelInfo(BaseModel):
    """A predictable label paired with the color used to display it."""

    name: str
    color: str
|
||
|
||
|
||
class PerClassMetrics(BaseModel):
    """Evaluation metrics recorded for a single label."""

    label: str
    precision: float
    recall: float
    f1_score: float
    # Number of samples of this label in the evaluation set.
    support: int
|
||
|
||
|
||
class ModelInfoResponse(BaseModel):
    """Response of GET /model/info: model identity plus labels and per-class metrics."""

    model_name: str
    model_version: Optional[str] = Field(default=None)
    model_type: str
    trained_at: Optional[str] = Field(default=None)
    dataset_version: Optional[str] = Field(default=None)
    feature_engineering_enabled: bool
    labels: List[str]
    per_class_metrics: List[PerClassMetrics]
|
||
|
||
|
||
class HealthResponse(BaseModel):
    """Response of GET /health: overall status plus per-dependency status."""

    status: str = Field(description="healthy, degraded, or unhealthy")
    model_loaded: bool
    mlflow: str = Field(description="healthy or unhealthy")
    database: str = Field(description="healthy or unhealthy")
|
||
|
||
|
||
# --- Model Loading Functions ---
|
||
|
||
def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from MLflow model registry.

    The version currently at ``stage`` is resolved first and then loaded by its
    explicit version number. The returned model and its metadata therefore always
    describe the same version, even if the stage is re-assigned in between.

    Args:
        model_name: Name of the registered model
        stage: Model stage (Production, Staging, None)

    Returns:
        Tuple of (model, model_info). ``model_info`` contains model_name,
        model_version, model_type, trained_at (ISO-8601 in UTC),
        dataset_version, feature_engineering_enabled, labels and
        per_class_metrics, read from the training run's params/metrics.

    Raises:
        ValueError: If no version of ``model_name`` is at ``stage``.
        Exception: Any MLflow / connection error (logged, then re-raised).
    """
    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")

    try:
        client = mlflow.tracking.MlflowClient()
        model_versions = client.get_latest_versions(model_name, stages=[stage])
        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")

        model_version = model_versions[0]

        # Load the exact resolved version (not "models:/<name>/<stage>") so the
        # artifact cannot differ from the version whose metadata we read below.
        model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_version.version}")

        run = client.get_run(model_version.run_id)
        params = run.data.params
        metrics = run.data.metrics

        model_info: Dict[str, Any] = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            # MLflow start_time is epoch milliseconds; render as explicit UTC
            # rather than an ambiguous naive local time.
            "trained_at": datetime.fromtimestamp(
                run.info.start_time / 1000, tz=timezone.utc
            ).isoformat(),
            "dataset_version": params.get("dataset_version", None),
            "feature_engineering_enabled": params.get("feature_engineering", "true") == "true",
            "labels": json.loads(params.get("labels", "[]")),
            "per_class_metrics": [],
        }

        # Per-class metrics are logged as precision_<label>, recall_<label>,
        # f1_<label> metrics plus a support_<label> param.
        for key, precision in metrics.items():
            if not key.startswith("precision_"):
                continue
            # removeprefix strips only the leading marker; str.replace would
            # also corrupt a label that itself contains "precision_".
            label = key.removeprefix("precision_")
            model_info["per_class_metrics"].append({
                "label": label,
                "precision": precision,
                "recall": metrics.get(f"recall_{label}", 0.0),
                "f1_score": metrics.get(f"f1_{label}", 0.0),
                "support": int(params.get(f"support_{label}", 0)),
            })

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise
|
||
|
||
|
||
def verify_model_checksum(model_path: Path) -> None:
    """
    Verify model file integrity against SHA256 checksum manifest.

    Reads ``models/checksums.sha256`` (relative to the working directory).
    Blank lines and lines starting with "#" are ignored. Every other line is
    ``<sha256hash> <filename>``. The binary-mode marker ``*`` that
    ``sha256sum -b`` writes before the filename is accepted, and hashes are
    compared case-insensitively. An entry matches when its filename equals
    the model file's name or its full path as given.

    Fail-open: a missing or unreadable manifest, or a model with no entry,
    only logs a warning and returns.

    Args:
        model_path: Path to the model file to verify.

    Raises:
        HTTPException: 500 if the file cannot be hashed or the hash mismatches.
    """
    manifest_path = Path("models/checksums.sha256")
    model_path = Path(model_path)

    if not manifest_path.exists():
        logger.warning("Model checksum manifest not found at %s — skipping integrity check", manifest_path)
        return

    # Parse manifest: each line is "<sha256hash> <filename>"
    checksums: Dict[str, str] = {}
    try:
        for line in manifest_path.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split(None, 1)
            if len(parts) != 2:
                continue
            digest, name = parts[0].strip(), parts[1].strip()
            # `sha256sum -b` writes "<hash> *<name>". Without removing the '*'
            # the entry would never match and the check would silently be skipped.
            checksums[name.removeprefix("*")] = digest.lower()
    except Exception as e:
        logger.warning("Failed to read checksum manifest: %s — skipping integrity check", e)
        return

    filename = model_path.name
    expected_hash = checksums.get(filename) or checksums.get(model_path.as_posix())
    if expected_hash is None:
        logger.warning("No checksum entry for '%s' in manifest — skipping integrity check", filename)
        return

    # Compute SHA256 of the actual file in fixed-size chunks.
    sha256 = hashlib.sha256()
    try:
        with open(model_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
    except Exception as e:
        logger.error("Failed to compute SHA256 for '%s': %s", model_path, e)
        raise HTTPException(status_code=500, detail="Internal server error")

    actual_hash = sha256.hexdigest()  # always lowercase hex

    if actual_hash != expected_hash:
        logger.error(
            "Model integrity check FAILED for '%s': expected %s, got %s",
            filename, expected_hash, actual_hash
        )
        raise HTTPException(status_code=500, detail="Internal server error")

    logger.info("Model integrity check passed for '%s'", filename)
|
||
|
||
|
||
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from local file using joblib.

    The file may hold a bare estimator, or a dict with a "model" entry and an
    optional "metadata" dict. When metadata lists no labels, they fall back to
    the estimator's ``classes_`` (or a wrapped ``.model.classes_``).

    Args:
        model_path: Path to .pkl model file

    Returns:
        Tuple of (model, model_info)

    Raises:
        FileNotFoundError: If model file doesn't exist
        HTTPException: If checksum verification fails (from verify_model_checksum)
        ValueError: If a dict payload has no (or a None) "model" entry
        Exception: If model loading fails
    """
    model_path = Path(model_path)
    logger.info(f"Loading model from local file: {model_path}")

    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # Verify model integrity before loading
    verify_model_checksum(model_path)

    try:
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code. Only load trusted artifacts. Note that verify_model_checksum is
        # fail-open when the manifest or the file's entry is missing.
        model_data = joblib.load(model_path)

        if isinstance(model_data, dict):
            model = model_data.get("model")
            # `or {}` also covers an explicit "metadata": None entry.
            metadata = model_data.get("metadata") or {}
        else:
            model = model_data
            metadata = {}

        if model is None:
            # Otherwise startup would report success while state.model is None.
            raise ValueError(f"Model file {model_path} does not contain a 'model' object")

        # Extract labels from model if not in metadata
        labels = metadata.get("labels", [])
        if not labels:
            if hasattr(model, 'classes_'):
                labels = [str(c) for c in model.classes_]
            elif hasattr(model, 'model') and hasattr(model.model, 'classes_'):
                labels = [str(c) for c in model.model.classes_]

        model_info = {
            "model_name": model_path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": labels,
            "per_class_metrics": metadata.get("per_class_metrics", []),
            "feature_columns": metadata.get("feature_columns", None),
        }

        logger.info(f"Successfully loaded local model: {model_path.name}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise
|
||
|
||
|
||
# --- Health Check ---
|
||
|
||
def _check_database() -> str:
    """Return "healthy" if ``SELECT 1`` runs and the session closes cleanly."""
    try:
        with get_db() as db:
            db.execute(sa_text("SELECT 1"))
        return "healthy"
    except Exception as exc:
        logger.warning(f"Health check: database unhealthy: {exc}")
        return "unhealthy"


def _check_mlflow() -> str:
    """
    Probe the MLflow backend named by MLFLOW_TRACKING_URI.

    Defaults to http://localhost:5000. HTTP(S) URIs are probed via GET
    ``<uri>/health`` with a 3 s timeout. Other URIs (e.g. a database URI) have
    no HTTP endpoint, so the store is exercised through MlflowClient instead.
    """
    from urllib.parse import urlparse

    mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
    try:
        scheme = urlparse(mlflow_tracking_uri).scheme.lower()
        if scheme in ("http", "https"):
            mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health"
            resp = http_requests.get(mlflow_health_url, timeout=3)
            if resp.status_code == 200:
                return "healthy"
            logger.warning(f"Health check: MLflow returned HTTP {resp.status_code}")
            return "unhealthy"

        # Previously a non-HTTP URI was always reported unhealthy, because
        # requests cannot GET e.g. "postgresql://.../health".
        mlflow.tracking.MlflowClient(tracking_uri=mlflow_tracking_uri).search_experiments(max_results=1)
        return "healthy"
    except Exception as exc:
        logger.warning(f"Health check: MLflow unreachable: {exc}")
        return "unhealthy"


@app.get("/health", response_model=HealthResponse)
async def health_check(response: Response):
    """
    Health check endpoint.

    Returns service status, model loaded status, and dependency health.
    Overall status:
    - "healthy" when the database and MLflow are both healthy;
    - otherwise "degraded" if a model is loaded, else "unhealthy".

    The response is HTTP 503 if any component (database or MLflow) is unhealthy.
    """
    import asyncio

    model_loaded = state.model is not None

    # Both probes do blocking I/O (DB round-trip; HTTP with a 3 s timeout).
    # Run them concurrently in worker threads so this async handler does not
    # stall the event loop for every other request.
    db_status, mlflow_status = await asyncio.gather(
        asyncio.to_thread(_check_database),
        asyncio.to_thread(_check_mlflow),
    )

    if db_status == "healthy" and mlflow_status == "healthy":
        overall_status = "healthy"
    elif model_loaded:
        overall_status = "degraded"
    else:
        overall_status = "unhealthy"

    # Return HTTP 503 if any dependency is unhealthy
    if db_status != "healthy" or mlflow_status != "healthy":
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE

    return HealthResponse(
        status=overall_status,
        model_loaded=model_loaded,
        mlflow=mlflow_status,
        database=db_status,
    )
|
||
|
||
|
||
# --- Model Info Endpoints (Stubs) ---
|
||
|
||
@app.get("/model/info", response_model=ModelInfoResponse, dependencies=[Depends(verify_api_key)])
async def get_model_info():
    """
    Describe the loaded model: identity, provenance, labels and per-class metrics.

    Raises:
        HTTPException: 503 when no model or no model metadata is loaded.
    """
    info = state.model_info
    if state.model is None or info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Required keys are indexed directly; optional ones fall back via .get().
    return ModelInfoResponse(
        model_name=info["model_name"],
        model_version=info.get("model_version"),
        model_type=info["model_type"],
        trained_at=info.get("trained_at"),
        dataset_version=info.get("dataset_version"),
        feature_engineering_enabled=info["feature_engineering_enabled"],
        labels=info.get("labels", []),
        per_class_metrics=list(
            map(lambda entry: PerClassMetrics(**entry), info.get("per_class_metrics", []))
        ),
    )
|
||
|
||
|
||
@app.get("/model/labels", response_model=List[LabelInfo], dependencies=[Depends(verify_api_key)])
async def get_model_labels():
    """
    List the labels the loaded model can emit, each with a display color.

    Every label currently gets the placeholder color "#888888".

    Raises:
        HTTPException: 503 when no model is loaded.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # TODO: Load label colors from config or database
    info = state.model_info
    label_names = info.get("labels", []) if info else []
    return [LabelInfo(name=name, color="#888888") for name in label_names]
|
||
|
||
|
||
# --- Prediction Helper Functions ---
|
||
|
||
def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Merge runs of consecutive same-label predictions into spans.

    "O" (background) predictions never form spans; they only close the span
    that is currently open. A span also closes when the label changes.
    Timestamps are not checked for gaps. Each span's ``avg_confidence`` is
    maintained as a running mean as predictions are folded in.

    Args:
        predictions: Per-window predictions, in order.

    Returns:
        The spans, in the order their runs appear.
    """
    if not predictions:
        return []

    spans: List[PredictionSpan] = []
    open_span: Optional[PredictionSpan] = None

    for item in predictions:
        if open_span is not None and item.label == open_span.label:
            # Same label: extend the open span and update its running mean.
            open_span.end_time = item.time
            open_span.candle_count += 1
            prior_total = open_span.avg_confidence * (open_span.candle_count - 1)
            open_span.avg_confidence = (prior_total + item.confidence) / open_span.candle_count
            continue

        # Label changed (or background): close whatever span is open.
        if open_span is not None:
            spans.append(open_span)
            open_span = None

        # Background predictions never open a new span.
        if item.label != "O":
            open_span = PredictionSpan(
                start_time=item.time,
                end_time=item.time,
                label=item.label,
                avg_confidence=item.confidence,
                candle_count=1,
            )

    if open_span is not None:
        spans.append(open_span)

    logger.info(f"Grouped {len(predictions)} predictions into {len(spans)} spans")
    return spans
|
||
|
||
|
||
# --- Prediction Endpoints ---
|
||
|
||
@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
async def predict(request: PredictRequest):
    """
    Predict candlestick patterns for provided candles.

    Steps:
    1. Sort the candles by time.
    2. Preprocess them (feature engineering + windowing) with the loaded
       pipeline config.
    3. Score every window with the current model. Each window yields one
       PredictionResult stamped with its last candle's time.
    4. Merge consecutive non-"O" labels into spans.

    Raises:
        HTTPException:
            503 when no model is loaded.
            400 for empty or duplicate-timestamp candles, a feature-count
            mismatch, or a ValueError during preprocessing/prediction.
            500 when the pipeline config is missing or prediction fails.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    if not request.candles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Candles array cannot be empty"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")

    # Auto-sort candles by time in ascending order to ensure chronological
    # processing regardless of the order in which the client sends them.
    candles_sorted = sorted(request.candles, key=lambda c: c.time)

    # Validate that there are no duplicate timestamps after sorting.
    times = [c.time for c in candles_sorted]
    duplicate_times = [t for i, t in enumerate(times) if i > 0 and t == times[i - 1]]
    if duplicate_times:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Duplicate candle timestamps detected: {duplicate_times[:5]}"
        )

    # Grab model reference under lock to prevent reading a partially-swapped model
    with _model_swap_lock:
        current_model = state.model
        current_model_info = state.model_info

    if current_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    try:
        candles_data = [candle.model_dump() for candle in candles_sorted]

        # Preprocess candles (feature engineering + windowing)
        training_feature_columns = (
            current_model_info.get("feature_columns") if current_model_info else None
        )
        X, window_times = preprocess_candles(
            candles_data,
            state.pipeline_config,
            training_feature_columns=training_feature_columns
        )

        expected_n_features = _get_model_n_features(current_model)
        if expected_n_features is not None and X.shape[1] != expected_n_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    f"Feature mismatch: model expects {expected_n_features} features, "
                    f"but preprocessing produced {X.shape[1]}. "
                    "Ensure the loaded model matches the inference preprocessing config."
                ),
            )

        # Predictions and confidences (using local reference, outside lock).
        y_pred = current_model.predict(X)
        if hasattr(current_model, 'predict_proba'):
            # Confidence is the highest class probability.
            confidences = np.max(current_model.predict_proba(X), axis=1)
        else:
            # Fallback for models without predict_proba
            confidences = np.ones(len(y_pred))

        # Get label names (handle both string and int predictions)
        if state.label_encoder is not None:
            labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
        else:
            labels = [str(pred) for pred in y_pred]

        # Build per-window predictions (each window maps to its last candle time)
        predictions = [
            PredictionResult(time=int(time), label=label, confidence=float(conf))
            for time, label, conf in zip(window_times, labels, confidences)
        ]

        spans = group_prediction_spans(predictions)

        model_info = ModelInfo(
            model_name=current_model_info["model_name"],
            model_version=current_model_info.get("model_version"),
            model_type=current_model_info["model_type"],
            trained_at=current_model_info.get("trained_at"),
            dataset_version=current_model_info.get("dataset_version"),
            feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
        )

        logger.info(
            f"Prediction complete: {len(predictions)} windows, "
            f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
        )

        return PredictResponse(
            predictions=predictions,
            spans=spans,
            model_info=model_info
        )

    except HTTPException:
        # Deliberate HTTP errors raised above (e.g. the 400 feature mismatch)
        # must reach the client unchanged. Without this clause the generic
        # handler below converted them into a 500.
        raise
    except ValueError as e:
        logger.error(f"Prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
|
||
|
||
|
||
@app.post("/predict/batch", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
async def predict_batch(request: BatchPredictRequest):
    """
    Batch prediction for a date range.

    Steps:
    1. Read candles from ``pipeline_config.data.raw_path``.
    2. When a ``time`` column exists, keep the candles whose time falls on any
       day from ``start_date`` through ``end_date`` (inclusive), sorted
       chronologically.
    3. Predict in chunks of ``stages.inference.batch_size`` rows. Each chunk
       after the first is prefixed with the warm-up rows that preprocessing
       needs, so windows straddling a chunk boundary are not lost.

    Raises:
        HTTPException:
            503 when no model is loaded.
            400 for bad dates or a feature-count mismatch.
            404 when no data is found.
            500 when the config is missing or an unexpected error occurs.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    # Validate date range
    try:
        start_dt = datetime.strptime(request.start_date, "%Y-%m-%d")
        end_dt = datetime.strptime(request.end_date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid date format. Use YYYY-MM-DD",
        ) from None

    if end_dt < start_dt:
        # Equal dates are valid (a single day), so the message says "not before".
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="end_date must not be before start_date",
        )

    if (end_dt - start_dt).days > 365:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Date range cannot exceed 1 year",
        )

    logger.info(
        f"Batch predict: {request.pair} {request.timeframe} "
        f"from {request.start_date} to {request.end_date}"
    )

    # Grab model reference under lock to prevent reading a partially-swapped model
    with _model_swap_lock:
        current_model = state.model
        current_model_info = state.model_info

    if current_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    try:
        # NOTE(review): request.pair / request.timeframe are not used to pick
        # the data source; every request reads data.raw_path. Confirm intent.
        raw_path = Path(state.pipeline_config.data.raw_path)

        if not raw_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No data found for {request.pair} {request.timeframe}"
            )

        df_raw = pd.read_csv(raw_path)

        if 'time' in df_raw.columns:
            # Half-open range [start_date 00:00, end_date + 1 day 00:00) so every
            # candle on end_date is included. Comparing against end_date's
            # midnight kept only its 00:00 candle. Naive dates are treated as
            # UTC by pandas' Timestamp.timestamp().
            start_ts = int(pd.Timestamp(request.start_date).timestamp())
            end_ts = int((pd.Timestamp(request.end_date) + pd.Timedelta(days=1)).timestamp())

            in_range = (df_raw['time'] >= start_ts) & (df_raw['time'] < end_ts)
            # Process chronologically, as /predict does for client-sent candles.
            df_filtered = df_raw[in_range].sort_values('time', kind='stable')

            if len(df_filtered) == 0:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"No data found in date range {request.start_date} to {request.end_date}"
                )

            logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
        else:
            df_filtered = df_raw
            logger.warning("No 'time' column found, using all data")

        batch_size = state.pipeline_config.stages.inference.batch_size
        training_feature_columns = (
            current_model_info.get("feature_columns") if current_model_info else None
        )
        expected_n_features = _get_model_n_features(current_model)
        total_rows = len(df_filtered)

        all_predictions: List[PredictionResult] = []
        # Rows consumed by preprocessing before its first window, measured on
        # the first chunk. Later chunks are prefixed with that many preceding
        # rows, so their first new row still gets a full-history window.
        warmup_rows = 0

        for i in range(0, total_rows, batch_size):
            chunk_end = min(i + batch_size, total_rows)
            context_start = max(0, i - warmup_rows)
            batch_df = df_filtered.iloc[context_start:chunk_end]
            new_rows = chunk_end - i

            logger.info(f"Processing batch {i // batch_size + 1}: rows {i} to {chunk_end}")

            # Preprocess (feature engineering + windowing)
            X, window_times = preprocess_candles(
                batch_df.to_dict('records'),
                state.pipeline_config,
                training_feature_columns=training_feature_columns
            )

            if i == 0:
                warmup_rows = max(0, len(batch_df) - len(window_times))

            if expected_n_features is not None and X.shape[1] != expected_n_features:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=(
                        f"Feature mismatch: model expects {expected_n_features} features, "
                        f"but preprocessing produced {X.shape[1]}. "
                        "Ensure the loaded model matches the inference preprocessing config."
                    ),
                )

            # Predict (using local reference, outside lock)
            y_pred = current_model.predict(X)
            if hasattr(current_model, 'predict_proba'):
                confidences = np.max(current_model.predict_proba(X), axis=1)
            else:
                confidences = np.ones(len(y_pred))

            if state.label_encoder is not None:
                labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
            else:
                labels = [str(pred) for pred in y_pred]

            batch_predictions = [
                PredictionResult(time=int(time), label=label, confidence=float(conf))
                for time, label, conf in zip(window_times, labels, confidences)
            ]

            # Any windows beyond `new_rows` end inside the context prefix and
            # were already emitted by the previous chunk; drop them. Windows
            # are assumed to be in chronological order.
            skip = max(0, len(batch_predictions) - new_rows)
            all_predictions.extend(batch_predictions[skip:])

        all_spans = group_prediction_spans(all_predictions)

        model_info = ModelInfo(
            model_name=current_model_info["model_name"],
            model_version=current_model_info.get("model_version"),
            model_type=current_model_info["model_type"],
            trained_at=current_model_info.get("trained_at"),
            dataset_version=current_model_info.get("dataset_version"),
            feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
        )

        logger.info(
            f"Batch prediction complete: {len(all_predictions)} candles, "
            f"{len(all_spans)} spans"
        )

        return PredictResponse(
            predictions=all_predictions,
            spans=all_spans,
            model_info=model_info
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pattern Detection Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class PatternInfo(BaseModel):
    """A single supported CDL pattern."""

    # CDL function identifier as reported by get_available_patterns().
    function_name: str = Field(...)
    # Human-readable name for the same pattern.
    display_name: str = Field(...)
|
||
|
||
|
||
class DetectPatternsRequest(BaseModel):
    """Request model for POST /patterns/detect."""

    # At least one candle is required (CandleData is defined elsewhere in this module).
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
    # default_factory builds a fresh list per request instead of sharing a
    # mutable default object; an empty list means "run every pattern".
    patterns: List[str] = Field(
        default_factory=list,
        description="CDL function names to run. Empty list means run all.",
    )
|
||
|
||
|
||
class SpanAnnotation(BaseModel):
    """A span annotation returned by pattern detection."""

    # Span boundaries as integers (presumably the same time unit as the
    # candles' `time` field — confirm against detect_patterns).
    start_time: int = Field(...)
    end_time: int = Field(...)
    # Pattern label attached to the span.
    label: str = Field(...)
    # Detector confidence, constrained to the closed interval [0, 1].
    confidence: float = Field(ge=0.0, le=1.0)
    # Origin of the annotation as reported by the detector.
    source: str = Field(...)
    # Free-text notes from the detector.
    notes: str = Field(...)
|
||
|
||
|
||
class DetectPatternsResponse(BaseModel):
    """Response model for POST /patterns/detect."""

    # One entry per detected pattern span.
    annotations: List[SpanAnnotation] = Field(...)
    # Free-form metadata; patterns_detect fills in "source" and "count".
    metadata: Dict[str, Any] = Field(...)
|
||
|
||
|
||
@app.get("/patterns/available", response_model=List[PatternInfo], dependencies=[Depends(verify_api_key)])
async def patterns_available():
    """
    List every CDL pattern this service can detect.

    The payload comes directly from ``get_available_patterns()`` and is
    validated against ``PatternInfo`` by the response model.
    """
    available = get_available_patterns()
    return available
|
||
|
||
|
||
@app.post("/patterns/detect", response_model=DetectPatternsResponse, dependencies=[Depends(verify_api_key)])
async def patterns_detect(request: DetectPatternsRequest):
    """
    Detect TA-Lib CDL patterns on provided candle data.

    - Empty ``patterns`` list runs all available CDL functions.
    - Invalid pattern names return HTTP 400.
    - A ``RuntimeError`` from the detector returns HTTP 503.
    - Any other failure (including annotations that fail validation) is
      logged and returns HTTP 500.
    """
    # 1.4 – Validate pattern names before doing any work
    if request.patterns:
        invalid = validate_pattern_names(request.patterns)
        if invalid:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
                f"Use GET /patterns/available to see supported patterns.",
            )

    candles_data = [c.model_dump() for c in request.candles]

    try:
        # An empty selection is passed as None, which the detector treats as "all".
        raw_annotations = detect_patterns(candles_data, request.patterns or None)
        annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
    except RuntimeError as exc:
        logger.error(f"Pattern detection runtime error: {exc}", exc_info=True)
        # 503 signals the detector is unavailable; the detail now says so
        # instead of the contradictory "Internal server error".
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Pattern detection is unavailable",
        )
    except Exception as exc:
        logger.error(f"Pattern detection failed: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    return DetectPatternsResponse(
        annotations=annotations,
        metadata={"source": "talib", "count": len(annotations)},
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Training Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Model types accepted by POST /training/start (checked in training_start).
SUPPORTED_MODEL_TYPES = [
    "random_forest",
    "xgboost",
]
|
||
|
||
|
||
class TrainingStartRequest(BaseModel):
    """Request model for POST /training/start."""

    # Plain string; the endpoint (not pydantic) rejects unsupported values with HTTP 400.
    model_type: str = Field(
        default="random_forest",
        description="Model type: random_forest or xgboost",
    )
    # Optional positive chart ID; None lets the dataset builder choose a chart.
    chart_id: Optional[int] = Field(
        None,
        ge=1,
        description="Chart ID to train on. If omitted, falls back to first chart.",
    )
|
||
|
||
|
||
class TrainingStartResponse(BaseModel):
    """Response model for POST /training/start."""

    # Identifier of the newly created training run.
    run_id: str = Field(...)
    # Initial run status as reported by the endpoint.
    status: str = Field(...)
|
||
|
||
|
||
class TrainingRunInfo(BaseModel):
    """Summary of a single training run."""

    run_id: str = Field(...)
    model_type: str = Field(...)
    status: str = Field(...)
    # MLflow experiment the run was filed under, if recorded.
    experiment_name: Optional[str] = Field(default=None)
    # ISO-8601 strings produced by the endpoints via datetime.isoformat().
    created_at: Optional[str] = Field(default=None)
    completed_at: Optional[str] = Field(default=None)
    # Metrics on success, or {"error": ...} on failure.
    metrics_summary: Optional[Dict[str, Any]] = Field(default=None)
|
||
|
||
|
||
class TrainingRunsResponse(BaseModel):
    """Response model for GET /training/runs."""

    # Runs in the order the endpoint returns them (newest first).
    runs: List[TrainingRunInfo] = Field(...)
|
||
|
||
|
||
class DatasetInfoResponse(BaseModel):
    """Response model for GET /training/dataset-info."""

    # Path of the labeled CSV as configured.
    path: str = Field(...)
    # Whether the file is present; the optional fields stay None when it is not.
    exists: bool = Field(...)
    size_bytes: Optional[int] = Field(default=None)
    last_modified: Optional[str] = Field(default=None)
    row_count: Optional[int] = Field(default=None)
|
||
|
||
|
||
def build_dataset_from_db(config: PipelineConfig, chart_id: Optional[int] = None) -> dict:
    """
    Build the labeled training dataset directly from the database.

    Steps:
        1. Export the chart's candles (time/open/high/low/close) to
           ``config.data.raw_path``.
        2. Run feature engineering; its output is read back from
           ``config.data.enriched_path``.
        3. Ingest span annotations from the database (all sources, including
           ``talib``) and write the labeled rows to ``config.data.labeled_path``.

    Args:
        config: Pipeline configuration providing data paths and the
            annotation-ingestion settings.
        chart_id: Chart to build from. When ``None``, the first row returned by
            ``DataAccess.get_all_charts()`` is used.

    Returns:
        dict with keys: chart_name, n_candles, n_samples, n_features, labeled_path

    Raises:
        ValueError: If the requested chart does not exist, no charts exist,
            the chart has no candles, or ingestion produces no labeled samples.
    """
    from app.data_access import DataAccess
    from app.annotation_ingestion import AnnotationIngestion
    from features.engineer import run_feature_engineering_stage

    data_access = DataAccess()

    # Resolve target chart (a dict from get_chart_by_id or a DataFrame row)
    if chart_id is not None:
        chart = data_access.get_chart_by_id(chart_id)
        if not chart:
            raise ValueError(f"Chart not found: id={chart_id}")
    else:
        charts_df = data_access.get_all_charts()
        if charts_df.empty:
            raise ValueError("No charts found in database. Upload candle data first.")
        chart = charts_df.iloc[0]

    chart_name = chart["name"]
    chart_id_int = int(chart["id"])
    logger.info(f"Building dataset for chart: {chart_name} (id={chart_id_int})")

    # Step 1: Export candles to raw CSV
    candles_df = data_access.get_candles(chart_id_int)
    if candles_df.empty:
        raise ValueError(f"No candles found for chart: {chart_name}")

    raw_path = Path(config.data.raw_path)
    raw_path.parent.mkdir(parents=True, exist_ok=True)

    # Only the OHLC columns plus time are handed to feature engineering
    export_df = candles_df[["time", "open", "high", "low", "close"]].copy()
    export_df.to_csv(raw_path, index=False)
    logger.info(f"Exported {len(export_df)} candles to {raw_path}")

    # Step 2: Run feature engineering
    run_feature_engineering_stage(config)
    enriched_path = Path(config.data.enriched_path)
    logger.info(f"Feature engineering complete: {enriched_path}")

    # Step 3: Run annotation ingestion from database
    enriched_df = pd.read_csv(enriched_path, parse_dates=["time"])
    ingestion = AnnotationIngestion(config.stages.annotation_ingestion)
    # Include all annotation sources so TA-Lib generated spans (source='talib')
    # can be used for training alongside manual labels.
    labeled_df = ingestion.process_from_db(enriched_df, chart_name, source=None)

    if labeled_df.empty:
        raise ValueError(
            f"No labeled samples produced. "
            f"Ensure you have span annotations on chart '{chart_name}'."
        )

    # Write labeled dataset
    labeled_path = Path(config.data.labeled_path)
    labeled_path.parent.mkdir(parents=True, exist_ok=True)
    labeled_df.to_csv(labeled_path, index=False)

    # Count features with the same exclusions _run_training_background uses
    # (label, time/timestamp, label_programmatic_*), so this number agrees
    # with the n_features metric recorded for a training run.
    n_features = sum(
        1
        for col in labeled_df.columns
        if col not in ("label", "time", "timestamp")
        and not str(col).startswith("label_programmatic_")
    )

    result = {
        "chart_name": chart_name,
        "n_candles": len(export_df),
        "n_samples": len(labeled_df),
        "n_features": n_features,
        "labeled_path": str(labeled_path),
    }
    logger.info(f"Dataset built: {result}")
    return result
|
||
|
||
|
||
# Maximum in-memory size (pandas deep memory usage) of the labeled dataset.
_TRAINING_DATASET_SIZE_LIMIT_BYTES = 500 * 1024 * 1024  # 500 MB

# Wall-clock budget for fitting a single model.
_TRAINING_TIMEOUT_SECONDS = 1800  # 30 minutes


def _mark_training_run_failed(run_id: str, error: str, run_kind: str = "failed") -> None:
    """
    Best-effort update of a TrainingRun row to ``status="failed"``.

    Database errors are logged and swallowed so they never mask the caller's
    own error handling.

    Args:
        run_id: Training run to update.
        error: Message stored as ``metrics_summary["error"]``.
        run_kind: Word used in the log line if the update itself fails
            (e.g. ``"failed"`` or ``"timed-out"``).
    """
    try:
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.run_id == run_id)
                .values(
                    status="failed",
                    completed_at=datetime.now(timezone.utc),
                    metrics_summary={"error": error},
                )
            )
            db.execute(stmt)
            db.commit()
    except Exception as db_exc:
        logger.error(f"Failed to update DB for {run_kind} run {run_id}: {db_exc}")


def _fit_model_with_timeout(fit_fn, timeout_seconds: float):
    """
    Run ``fit_fn()`` in a worker thread and return its result.

    Unlike ``with ThreadPoolExecutor(...)``, whose exit calls
    ``shutdown(wait=True)`` and therefore blocks until the fit ends, this
    returns (or raises) as soon as the timeout expires. A Python thread cannot
    be killed, so an overrunning fit keeps running until it completes; its
    result is simply discarded.

    Raises:
        concurrent.futures.TimeoutError: If ``fit_fn`` has not finished
            within ``timeout_seconds``.
        Exception: Whatever ``fit_fn`` itself raises.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return executor.submit(fit_fn).result(timeout=timeout_seconds)
    finally:
        # wait=False: never block on a fit that overran its budget.
        executor.shutdown(wait=False)


def _run_training_background(
    run_id: str,
    model_type: str,
    config: PipelineConfig,
    chart_id: Optional[int] = None,
    user_id: Optional[str] = None,
) -> None:
    """
    Background thread target: build dataset then train a model.

    Uses the pre-inserted TrainingRun record identified by ``run_id`` and
    moves it to ``completed`` (with evaluation metrics) or ``failed`` (with
    an error message). On exit, ``state.active_training_run_id`` is cleared
    if it still refers to this run.

    If fitting exceeds ``_TRAINING_TIMEOUT_SECONDS`` the run is marked failed
    and this function returns right away, which frees the active-run slot.
    The abandoned fit keeps running in its worker thread until it finishes.

    Args:
        run_id: Training run ID
        model_type: Type of model to train
        config: Pipeline configuration
        chart_id: Optional chart ID to train on
        user_id: Optional user ID for scoped experiment naming
    """
    logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}, user_id={user_id}")

    try:
        # Import training utilities here to avoid circular import issues
        from training.train import create_model, temporal_split
        from sklearn.metrics import accuracy_score, f1_score

        # Set up MLflow experiment with user scoping
        mlflow_config = config.stages.training.mlflow
        mlflow.set_tracking_uri(mlflow_config.tracking_uri)
        if user_id:
            experiment_name = f"user_{user_id}_training"
        else:
            experiment_name = mlflow_config.experiment_name
        mlflow.set_experiment(experiment_name)
        logger.info(f"MLflow experiment set to: {experiment_name}")

        # Build dataset from database (feature engineering + annotation ingestion)
        logger.info("Building dataset from database...")
        build_dataset_from_db(config, chart_id=chart_id)

        labeled_path = Path(config.data.labeled_path)
        if not labeled_path.exists():
            raise FileNotFoundError(f"Labeled dataset not found: {labeled_path}")

        df = pd.read_csv(labeled_path)
        if "label" not in df.columns:
            raise ValueError("Labeled dataset must have 'label' column")

        # Reject datasets whose in-memory footprint exceeds the limit
        dataset_size_bytes = df.memory_usage(deep=True).sum()
        if dataset_size_bytes > _TRAINING_DATASET_SIZE_LIMIT_BYTES:
            raise ValueError(
                f"Dataset too large. Maximum size is 500MB "
                f"(current size: {dataset_size_bytes / (1024 * 1024):.1f}MB)."
            )
        logger.info(
            f"Dataset size check passed: {dataset_size_bytes / (1024 * 1024):.1f}MB"
        )

        # Every column except the target, timestamps and label_programmatic_*
        # columns is used as a model feature.
        feature_cols = [
            col for col in df.columns
            if col not in ("label", "time", "timestamp")
            and not col.startswith("label_programmatic_")
        ]
        X = df[feature_cols].values
        y = df["label"].values
        logger.info(f"Loaded {len(X)} samples, {len(feature_cols)} features")

        # Split data
        training_cfg = config.stages.training
        X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
            X, y, training_cfg.test_split, training_cfg.validation_split
        )

        def _do_train():
            _model = create_model(
                model_type, training_cfg.hyperparameters, training_cfg.class_weights
            )
            _model.fit(X_train, y_train)
            return _model

        timeout_msg = f"Training timed out after {_TRAINING_TIMEOUT_SECONDS // 60} minutes"
        try:
            model_instance = _fit_model_with_timeout(_do_train, _TRAINING_TIMEOUT_SECONDS)
        except concurrent.futures.TimeoutError:
            logger.error(f"{timeout_msg}: run_id={run_id}")
            _mark_training_run_failed(run_id, timeout_msg, run_kind="timed-out")
            return

        logger.info("Model training complete")

        # Evaluate
        y_val_pred = model_instance.predict(X_val)
        y_test_pred = model_instance.predict(X_test)
        metrics = {
            "val_accuracy": float(accuracy_score(y_val, y_val_pred)),
            "val_f1_macro": float(
                f1_score(y_val, y_val_pred, average="macro", zero_division=0)
            ),
            "test_accuracy": float(accuracy_score(y_test, y_test_pred)),
            "test_f1_macro": float(
                f1_score(y_test, y_test_pred, average="macro", zero_division=0)
            ),
            "n_samples": int(len(X)),
            "n_features": int(X.shape[1]),
        }

        # Save model locally
        models_dir = Path("models")
        models_dir.mkdir(exist_ok=True)
        model_path = models_dir / f"{run_id}.pkl"

        model_data = {
            "model": model_instance,
            "metadata": {
                "model_type": model_type,
                "trained_at": datetime.now(timezone.utc).isoformat(),
                "run_id": run_id,
                "feature_columns": feature_cols,
                "labels": (
                    [str(c) for c in model_instance.model.classes_]
                    if hasattr(model_instance, "model") and hasattr(model_instance.model, "classes_")
                    else []
                ),
            },
        }
        joblib.dump(model_data, model_path)
        logger.info(f"Model saved to {model_path}")

        # Update DB: completed
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.run_id == run_id)
                .values(
                    status="completed",
                    completed_at=datetime.now(timezone.utc),
                    metrics_summary=metrics,
                )
            )
            db.execute(stmt)
            db.commit()

        logger.info(f"Training completed successfully: run_id={run_id}")

    except Exception as exc:
        logger.error(f"Training failed for run_id={run_id}: {exc}", exc_info=True)
        _mark_training_run_failed(run_id, str(exc))
    finally:
        with state.training_lock:
            if state.active_training_run_id == run_id:
                state.active_training_run_id = None
        logger.info(f"Training thread exiting: run_id={run_id}")
|
||
|
||
|
||
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
async def training_start(request: TrainingStartRequest, user_id: Optional[str] = Depends(get_user_id)):
    """
    Start a training run in a background thread.

    Returns immediately with run_id and status "running".
    Rejects concurrent runs with HTTP 409.

    Args:
        request: Training request parameters
        user_id: Optional user ID from X-User-ID header for scoped experiments

    Raises:
        HTTPException: 400 for an unsupported model type or unknown chart,
            409 if another run is active, 500 if the run record cannot be
            stored or the worker thread cannot be started.
    """
    # Validate model type
    if request.model_type not in SUPPORTED_MODEL_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
        )

    if request.chart_id is not None:
        from app.data_access import DataAccess

        chart = DataAccess().get_chart_by_id(request.chart_id)
        if not chart:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Chart not found: id={request.chart_id}",
            )

    # Reject concurrent runs (atomic check-and-set)
    with state.training_lock:
        if state.active_training_run_id is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail={
                    "error": "Training already in progress",
                    "run_id": state.active_training_run_id,
                },
            )
        run_id = str(uuid.uuid4())
        state.active_training_run_id = run_id

    config = state.pipeline_config or get_default_config()

    # Compute config hash (best-effort: it is informational only)
    config_hash = "unknown"
    try:
        from training.train import compute_config_hash

        config_hash = compute_config_hash(config)
    except Exception as exc:
        logger.warning(f"Could not compute config hash for run_id={run_id}: {exc}")

    # Pre-insert the run record so callers can track it immediately
    try:
        # Compute scoped experiment name
        mlflow_config = config.stages.training.mlflow
        if user_id:
            experiment_name = f"user_{user_id}_training"
        else:
            experiment_name = mlflow_config.experiment_name

        with get_db() as db:
            training_run = TrainingRun(
                run_id=run_id,
                user_id=user_id,
                model_type=request.model_type,
                experiment_name=experiment_name,
                pipeline_config_hash=config_hash,
                status="running",
                created_at=datetime.now(timezone.utc),
                metrics_summary={},
            )
            db.add(training_run)
            db.commit()
    except Exception as exc:
        with state.training_lock:
            state.active_training_run_id = None
        logger.error(f"Failed to insert training run record: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # Launch background thread (daemon so it doesn't block process exit)
    thread = threading.Thread(
        target=_run_training_background,
        args=(run_id, request.model_type, config, request.chart_id, user_id),
        daemon=True,
        name=f"training-{run_id[:8]}",
    )
    try:
        thread.start()
    except RuntimeError as exc:
        # e.g. "can't start new thread". Without this cleanup the run would
        # stay "running" forever and every later start would get HTTP 409.
        with state.training_lock:
            if state.active_training_run_id == run_id:
                state.active_training_run_id = None
        logger.error(f"Failed to start training thread for run_id={run_id}: {exc}")
        try:
            with get_db() as db:
                db.execute(
                    sa_update(TrainingRun)
                    .where(TrainingRun.run_id == run_id)
                    .values(
                        status="failed",
                        completed_at=datetime.now(timezone.utc),
                        metrics_summary={"error": "Failed to start training thread"},
                    )
                )
                db.commit()
        except Exception as db_exc:
            logger.error(f"Failed to update DB for failed run {run_id}: {db_exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}, user_id={user_id or 'default'}")
    return TrainingStartResponse(run_id=run_id, status="running")
|
||
|
||
|
||
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
async def training_runs(user_id: Optional[str] = Depends(get_user_id)):
    """
    List the caller's training runs, newest first.

    Runs are scoped to the X-User-ID header value. Without a user ID the
    response is an empty list, so run history is never exposed to a caller
    without user context. Database errors surface as HTTP 500.
    """

    def _as_info(record) -> TrainingRunInfo:
        # Timestamps become ISO-8601 strings; missing ones stay None.
        created = record.created_at.isoformat() if record.created_at else None
        finished = record.completed_at.isoformat() if record.completed_at else None
        return TrainingRunInfo(
            run_id=record.run_id,
            model_type=record.model_type,
            status=record.status,
            experiment_name=record.experiment_name,
            created_at=created,
            completed_at=finished,
            metrics_summary=record.metrics_summary,
        )

    try:
        from sqlalchemy import select

        with get_db() as db:
            if not user_id:
                # No user context — return nothing to prevent data leakage
                return TrainingRunsResponse(runs=[])
            query = (
                select(TrainingRun)
                .where(TrainingRun.user_id == user_id)
                .order_by(desc(TrainingRun.created_at))
            )
            history = [_as_info(record) for record in db.execute(query).scalars().all()]
    except Exception as exc:
        logger.error(f"Failed to fetch training runs: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    return TrainingRunsResponse(runs=history)
|
||
|
||
|
||
class ActiveTrainingResponse(BaseModel):
    """Response model for GET /training/active."""

    # True while a training run holds the active slot.
    active: bool = Field(...)
    # ID of that run, or None when idle.
    run_id: Optional[str] = Field(default=None)
|
||
|
||
|
||
@app.get("/training/active", response_model=ActiveTrainingResponse, dependencies=[Depends(verify_api_key)])
async def training_active():
    """
    Report whether a training run is in progress and, if so, which one.
    """
    # Read the shared slot under the same lock the writers use.
    with state.training_lock:
        current = state.active_training_run_id
    is_active = current is not None
    return ActiveTrainingResponse(active=is_active, run_id=current)
|
||
|
||
|
||
class DeleteRunResponse(BaseModel):
    """Response model for DELETE /training/runs/{run_id}."""

    # The run that was removed.
    run_id: str = Field(...)
    # Always True on success; failures are reported via HTTP errors.
    deleted: bool = Field(...)
|
||
|
||
|
||
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
async def delete_training_run(run_id: str, user_id: Optional[str] = Depends(get_user_id)):
    """
    Delete a training run record and its model artifact.

    Returns HTTP 400 if the run_id format is invalid.
    Returns HTTP 409 if the run is currently active.
    Returns HTTP 404 if the run_id doesn't exist or belongs to a different user.
    A missing or undeletable artifact file is logged but does not fail the call.
    """
    from sqlalchemy import select, delete as sa_delete

    # Validate run_id format to prevent path traversal. fullmatch is required:
    # re.match with "$" would also accept a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    # Resolve and verify the artifact path *before* touching the DB, so a
    # rejected path can never leave the record deleted while reporting an error.
    # Comparing Path objects (instead of a "/"-suffixed string prefix) also
    # works with Windows path separators.
    models_base = Path("models").resolve()
    model_path = (models_base / f"{run_id}.pkl").resolve()
    if model_path.parent != models_base:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    # Reject deletion of the active run
    with state.training_lock:
        if state.active_training_run_id == run_id:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Cannot delete an active training run",
            )

    try:
        with get_db() as db:
            row = db.execute(
                select(TrainingRun).where(TrainingRun.run_id == run_id)
            ).scalar_one_or_none()

            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )

            # Enforce user ownership: return 404 (not 403) to avoid leaking existence.
            # NOTE(review): ownership is only enforced when both the caller and
            # the row carry a user_id, so a caller without X-User-ID can delete
            # any run (GET /training/runs returns nothing in that case). Confirm
            # this is intended.
            if user_id and row.user_id and row.user_id != user_id:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )

            db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
            db.commit()
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to delete training run {run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # Remove model artifact if it exists (best-effort)
    if model_path.exists():
        try:
            model_path.unlink()
            logger.info(f"Deleted model artifact: {model_path}")
        except Exception as exc:
            logger.warning(f"Could not delete model artifact {model_path}: {exc}")

    logger.info(f"Deleted training run: {run_id}")
    return DeleteRunResponse(run_id=run_id, deleted=True)
|
||
|
||
|
||
@app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)])
async def get_training_run(run_id: str, user_id: Optional[str] = Depends(get_user_id)):
    """
    Get information about a specific training run by run_id.

    Returns HTTP 400 if the run_id format is invalid.
    Returns HTTP 404 if the run_id doesn't exist or belongs to a different user.
    Returns HTTP 500 on database errors.
    """
    from sqlalchemy import select

    # Validate run_id format. fullmatch (not match + "$") so a trailing newline
    # is rejected with the documented 400 instead of falling through to a lookup.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    try:
        with get_db() as db:
            row = db.execute(
                select(TrainingRun).where(TrainingRun.run_id == run_id)
            ).scalar_one_or_none()

            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )

            # Enforce user ownership: return 404 (not 403) to avoid leaking existence.
            # NOTE(review): a caller without X-User-ID can read any run, while
            # GET /training/runs returns nothing without a user. Confirm intended.
            if user_id and row.user_id and row.user_id != user_id:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )

            return TrainingRunInfo(
                run_id=row.run_id,
                model_type=row.model_type,
                status=row.status,
                experiment_name=row.experiment_name,
                created_at=row.created_at.isoformat() if row.created_at else None,
                completed_at=row.completed_at.isoformat() if row.completed_at else None,
                metrics_summary=row.metrics_summary,
            )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to fetch training run {run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
|
||
|
||
|
||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
async def training_dataset_info():
    """
    Return information about the labeled training dataset.

    Includes file path, existence, size, last modified date (ISO-8601, UTC)
    and row count. ``row_count`` is None if the CSV cannot be parsed.
    """
    config = state.pipeline_config or get_default_config()
    labeled_path = Path(config.data.labeled_path)

    if not labeled_path.exists():
        return DatasetInfoResponse(path=str(labeled_path), exists=False)

    try:
        stat = labeled_path.stat()
        size_bytes = stat.st_size
        # Timezone-aware UTC, matching the other timestamps this API returns;
        # a naive local time is ambiguous to clients in other time zones.
        last_modified = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat()

        row_count = None
        try:
            # Read only one column for efficiency
            df_head = pd.read_csv(labeled_path, usecols=[0])
            row_count = len(df_head)
        except Exception as exc:
            # Best-effort: report the file without a row count.
            logger.warning(f"Could not count rows in {labeled_path}: {exc}")

        return DatasetInfoResponse(
            path=str(labeled_path),
            exists=True,
            size_bytes=size_bytes,
            last_modified=last_modified,
            row_count=row_count,
        )
    except Exception as exc:
        logger.error(f"Failed to get dataset info: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
|
||
|
||
|
||
class BuildDatasetResponse(BaseModel):
    """Response model for POST /training/build-dataset."""

    # Mirrors the dict returned by build_dataset_from_db.
    chart_name: str = Field(...)
    n_candles: int = Field(...)
    n_samples: int = Field(...)
    n_features: int = Field(...)
    labeled_path: str = Field(...)
|
||
|
||
|
||
@app.post("/training/build-dataset", response_model=BuildDatasetResponse, dependencies=[Depends(verify_api_key)])
async def training_build_dataset():
    """
    Build the labeled training dataset from database annotations.

    Exports candles, runs feature engineering, and ingests span annotations
    into a labeled CSV ready for training.

    Returns HTTP 409 while a training run is active, HTTP 400 when the data
    is insufficient to build a dataset, and HTTP 500 on unexpected errors.
    """
    # A training run calls build_dataset_from_db with the same config paths;
    # rebuilding now could replace those files underneath it.
    with state.training_lock:
        active_run_id = state.active_training_run_id
    if active_run_id is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "error": "Training already in progress",
                "run_id": active_run_id,
            },
        )

    config = state.pipeline_config or get_default_config()

    try:
        result = build_dataset_from_db(config)
        return BuildDatasetResponse(**result)
    except ValueError as exc:
        # build_dataset_from_db raises ValueError for "no charts", "no candles"
        # and "no labeled samples". Log the specifics server-side and return a
        # client-facing message that agrees with the 400 status.
        logger.warning(f"Dataset build rejected: {exc}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Unable to build dataset from the current data. "
                "Ensure a chart with candles and span annotations exists."
            ),
        )
    except Exception as exc:
        logger.error(f"Failed to build dataset: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Model Loading Endpoint
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ModelLoadRequest(BaseModel):
    """Request model for POST /model/load."""

    # Required; format is validated by the endpoint, not by pydantic.
    run_id: str = Field(description="Training run ID to load model from")
|
||
|
||
|
||
class ModelLoadResponse(BaseModel):
    """Response model for POST /model/load."""

    # Run whose artifact is now serving predictions.
    run_id: str = Field(...)
    # Model type recorded on the training run.
    model_type: str = Field(...)
    # model_load reports "loaded" on success.
    status: str = Field(...)
|
||
|
||
|
||
# Lock protecting model hot-swap
|
||
_model_swap_lock = threading.Lock()
|
||
|
||
|
||
@app.post("/model/load", response_model=ModelLoadResponse, dependencies=[Depends(verify_api_key)])
async def model_load(request: ModelLoadRequest):
    """
    Load a trained model from a completed training run.

    Looks up the run_id in the training_runs table, loads the model artifact,
    and replaces the active model in AppState with a brief lock to prevent
    conflicts with in-flight prediction requests.

    Returns HTTP 400 for a malformed run_id or a run that is not completed,
    HTTP 404 if the run or its artifact is missing, and HTTP 500 on database
    or deserialization errors.
    """
    import asyncio
    from sqlalchemy import select

    # 0. Validate run_id format to prevent path traversal. fullmatch is
    #    required: re.match with "$" would also accept a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', request.run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    # 1. Look up the training run
    # NOTE(review): no ownership check here — any API-key holder can load any
    # user's run. Confirm this is intended.
    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == request.run_id)
            row = db.execute(stmt).scalar_one_or_none()
    except Exception as exc:
        logger.error(f"DB lookup failed for run_id={request.run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Training run not found: {request.run_id}",
        )

    if row.status != "completed":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Training run is not completed (status={row.status})",
        )

    # 2. Resolve model artifact path. Comparing Path objects (instead of a
    #    "/"-suffixed string prefix) also works with Windows path separators.
    models_base = Path("models").resolve()
    model_path = (models_base / f"{request.run_id}.pkl").resolve()
    if model_path.parent != models_base:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    if not model_path.exists():
        # Keep the absolute filesystem path in server logs only.
        logger.warning(f"Model artifact missing for run_id={request.run_id}: {model_path}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model artifact not found for run_id={request.run_id}",
        )

    # 3. Load model off the event loop (can be slow) so in-flight prediction
    #    requests keep being served with the current model meanwhile.
    try:
        loop = asyncio.get_running_loop()
        new_model, new_model_info = await loop.run_in_executor(
            None, load_model_from_local, str(model_path)
        )
    except Exception as exc:
        logger.error(f"Failed to load model for run_id={request.run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # 4. Thread-safe model swap (3.2 – brief lock). This runs back on the event
    #    loop thread with no await in between, so the assignments are not
    #    interleaved with other coroutines.
    # NOTE(review): the prediction path also reads state.label_encoder, which is
    # not updated here — confirm it stays consistent with the swapped model.
    with _model_swap_lock:
        state.model = new_model
        state.model_info = new_model_info
        state.feature_columns = new_model_info.get("feature_columns")

    logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")

    return ModelLoadResponse(
        run_id=request.run_id,
        model_type=row.model_type,
        status="loaded",
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Direct-execution entry point: serve this app with uvicorn on all
    # interfaces, port 8001.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")
|