candle-annotator/services/ml/app/main.py
Marko Djordjevic 399e760fa5 Add X-User-ID header extraction to FastAPI service.
Create a get_user_id() dependency that extracts the X-User-ID header
from incoming requests, making it available to route handlers.
The dependency is optional (not enforced) — callers decide whether
to use it or require it on specific routes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-20 13:44:02 +01:00

1837 lines
63 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI inference service for candlestick pattern prediction.
Provides REST API endpoints for model serving, health checks, and prediction.
"""
import concurrent.futures
import hashlib
import logging
import os
import re
import threading
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone
import json
import requests as http_requests
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import numpy as np
import pandas as pd
import joblib
import mlflow
import mlflow.sklearn
import mlflow.xgboost
from sqlalchemy import update as sa_update, desc, text as sa_text
from app.config import load_config, PipelineConfig, get_default_config
from app.db import get_db, TrainingRun, init_db
from app.preprocessing import preprocess_candles, extract_feature_columns
from app.patterns import (
TALIB_PATTERNS,
get_available_patterns,
validate_pattern_names,
detect_patterns,
)
# Logging: INFO level, one "time - logger - level - message" line per record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every handler and helper in this module.
logger = logging.getLogger(__name__)
# --- API Key Dependency ---
async def verify_api_key(x_api_key: str = Header(default="")):
    """
    Verify the X-API-Key header against the API_KEY environment variable.

    Fail-open: when API_KEY is unset or empty, every request is allowed.
    The comparison is constant-time (``hmac.compare_digest``) so response
    timing does not reveal how much of a guessed key was correct.

    Raises:
        HTTPException: 401 when API_KEY is configured and the header does
            not match it exactly.
    """
    import hmac  # local import, matching the file's lazy-import style

    api_key = os.getenv("API_KEY")
    if not api_key:
        return  # fail-open if not configured
    # Compare bytes: compare_digest rejects str arguments containing non-ASCII
    # characters, and the header value is client-controlled. "surrogatepass"
    # keeps the encoding lossless for env values holding undecodable bytes.
    provided = x_api_key.encode("utf-8", "surrogatepass")
    expected = api_key.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(provided, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
# --- User ID Header Dependency ---
async def get_user_id(x_user_id: str = Header(default="")) -> Optional[str]:
    """
    FastAPI dependency exposing the X-User-ID request header.

    A missing header, or one sent with an empty value, yields None. Any
    non-empty value is passed through unchanged. Nothing is enforced here;
    each route decides whether a missing user ID is an error.
    """
    return x_user_id or None
# --- Lifespan ---
def _fail_stale_training_runs() -> None:
    """
    Mark every TrainingRun still in status "running" as "failed".

    Called at startup: such rows belong to a previous process and will never
    complete. Errors are logged rather than raised, so a database problem
    does not stop the service from starting.
    """
    try:
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.now(timezone.utc),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            result = db.execute(stmt)
            db.commit()
            if result.rowcount:
                logger.warning(
                    f"Marked {result.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as exc:
        logger.error(f"Failed to clean up stale training runs: {exc}")


def _load_startup_model(config: PipelineConfig) -> None:
    """
    Load the model selected by ``config.stages.inference`` into ``state``.

    On success, sets state.model, state.model_info and state.feature_columns
    together. Load failures, including an unknown ``model_source``, are
    logged and the service keeps running without a model.
    """
    inference_config = config.stages.inference
    if not inference_config.enabled:
        logger.warning("Inference stage is disabled in config")
        return
    try:
        if inference_config.model_source == "mlflow":
            # Configure MLflow tracking URI
            mlflow.set_tracking_uri(config.stages.training.mlflow.tracking_uri)
            model, model_info = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage
            )
        elif inference_config.model_source == "local":
            model, model_info = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        else:
            # Raise instead of continuing: the old code fell through and crashed
            # on model_info being None with a misleading AttributeError.
            raise ValueError(f"Unknown model_source: {inference_config.model_source}")
        # Publish only after loading fully succeeded.
        state.model = model
        state.model_info = model_info
        state.feature_columns = model_info.get("feature_columns")
        logger.info("Model loaded successfully")
        logger.info(f"Model info: {model_info.get('model_name')} "
                    f"v{model_info.get('model_version')} "
                    f"({model_info.get('model_type')})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        logger.warning("Service will start without a model. Use /health to check status.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager.

    Startup, in order:
    1. Create the training_runs table (errors propagate and abort startup).
    2. Fail stale "running" training rows.
    3. Load config/pipeline.yaml (relative to the working directory).
    4. If the config loaded, load the model it selects.

    A missing or broken config, or a failed model load, is logged and the
    service still starts. Shutdown currently has nothing to release.
    """
    # --- Startup ---
    logger.info("Starting inference service...")
    # Ensure training_runs table exists
    init_db()
    _fail_stale_training_runs()

    config_path = Path("config/pipeline.yaml")
    if config_path.exists():
        try:
            state.pipeline_config = load_config(config_path)
            logger.info(f"Loaded pipeline config from {config_path}")
            _load_startup_model(state.pipeline_config)
        except Exception as e:
            logger.error(f"Failed to load pipeline config: {e}")
    else:
        # NOTE(review): no default config is applied here; state.pipeline_config
        # stays None, so /predict answers 500 "Pipeline configuration not loaded".
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
    yield
    # --- Shutdown (nothing to do currently) ---
# FastAPI application; startup/shutdown logic lives in `lifespan`.
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0",
    lifespan=lifespan,
)
# Allowed CORS origins: comma-separated list from CORS_ORIGINS (default
# http://localhost:3000), with surrounding whitespace stripped per entry.
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
allow_origins = list(map(str.strip, cors_origins_str.split(",")))
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global state
class AppState:
    """
    Mutable, process-wide container for the service's runtime state.

    One instance, ``state``, is created at import time and shared by the
    lifespan hook and the request handlers.
    """

    # Currently loaded model object and its metadata dict (None until loaded).
    model: Optional[Any] = None
    model_info: Optional[Dict[str, Any]] = None
    # Parsed pipeline configuration (set by the lifespan hook).
    pipeline_config: Optional[PipelineConfig] = None
    # Feature column names recorded in the model metadata, if any.
    feature_columns: Optional[List[str]] = None
    # Optional int -> label-name mapping applied to raw model predictions.
    label_encoder: Optional[Dict[int, str]] = None
    # Training thread management
    active_training_run_id: Optional[str] = None
    # Class-level placeholder; each instance receives its own lock below.
    training_lock: threading.Lock = None

    def __init__(self) -> None:
        new_lock = threading.Lock()
        self.training_lock = new_lock


state = AppState()
def _get_model_n_features(model: Any) -> Optional[int]:
    """
    Return how many input features ``model`` expects, or None if unknown.

    ``n_features_in_`` is read from the model itself or, failing that, from a
    wrapped estimator exposed as ``model.model``. The outer attribute wins
    whenever it exists. A value that ``int()`` cannot convert yields None
    instead of raising.
    """
    if hasattr(model, "n_features_in_"):
        source = model
    elif hasattr(model, "model") and hasattr(model.model, "n_features_in_"):
        source = model.model
    else:
        return None
    try:
        return int(source.n_features_in_)
    except Exception:
        return None
# --- Pydantic Models ---
class CandleData(BaseModel):
    """Single candle data point."""
    time: int = Field(..., description="Unix timestamp in seconds")
    # OHLC prices. No cross-field checks (e.g. low <= high) are applied here.
    open: float
    high: float
    low: float
    close: float
    # Optional; clients may omit it.
    volume: Optional[float] = None
class PredictRequest(BaseModel):
    """Request model for /predict endpoint."""
    # Informational only: the /predict handler logs pair/timeframe but does
    # not use them for inference.
    pair: Optional[str] = Field(None, description="Trading pair (e.g., EURUSD)")
    timeframe: Optional[str] = Field(None, description="Timeframe (e.g., 1H, 4H, 1D)")
    # min_length=1 rejects an empty candle list during request validation.
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
class PredictionResult(BaseModel):
    """Single candle prediction."""
    # Timestamp the prediction is attributed to; the predict handlers use the
    # last candle time of each preprocessing window.
    time: int
    # Predicted label name; "O" is treated as background by span grouping.
    label: str
    confidence: float = Field(..., ge=0.0, le=1.0)
class PredictionSpan(BaseModel):
    """Grouped prediction span."""
    # Times of the first and last prediction in the run.
    start_time: int
    end_time: int
    label: str
    # Mean confidence over the run's predictions.
    avg_confidence: float = Field(..., ge=0.0, le=1.0)
    # Number of consecutive predictions merged into this span.
    candle_count: int
class ModelInfo(BaseModel):
    """Model metadata."""
    # Populated from the model_info dict built by load_model_from_mlflow /
    # load_model_from_local.
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    # ISO-8601 string from the MLflow loader; taken verbatim from local
    # model metadata (format not enforced).
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
class PredictResponse(BaseModel):
    """Response model for /predict endpoint."""
    # Also the response model of /predict/batch.
    # One entry per preprocessing window, in window order.
    predictions: List[PredictionResult]
    # Runs of identical non-"O" labels grouped from `predictions`.
    spans: List[PredictionSpan]
    model_info: ModelInfo
class BatchPredictRequest(BaseModel):
    """Request model for /predict/batch endpoint."""
    # Used by the handler for logging and error messages.
    pair: str
    timeframe: str
    # Plain strings here; the handler parses them with "%Y-%m-%d" and
    # answers 400 on a bad format or an invalid range.
    start_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
class LabelInfo(BaseModel):
    """Pattern label with display color."""
    name: str
    # Display color string, e.g. "#888888" (the placeholder /model/labels
    # currently returns for every label).
    color: str
class PerClassMetrics(BaseModel):
    """Per-class performance metrics."""
    # Built from model_info["per_class_metrics"] entries. For MLflow models
    # these come from run metrics precision_/recall_/f1_<label> and the
    # support_<label> run param.
    label: str
    precision: float
    recall: float
    f1_score: float
    # Sample count for this class (presumably on the evaluation split; TODO confirm).
    support: int
class ModelInfoResponse(BaseModel):
    """Extended model information with metrics."""
    # Same metadata fields as ModelInfo, plus labels and per-class metrics.
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
    # Class labels the model can predict.
    labels: List[str]
    per_class_metrics: List[PerClassMetrics]
class HealthResponse(BaseModel):
    """Health check response."""
    status: str = Field(..., description="healthy, degraded, or unhealthy")
    # True when a model object is currently held in the app state.
    model_loaded: bool
    mlflow: str = Field(..., description="healthy or unhealthy")
    database: str = Field(..., description="healthy or unhealthy")
# --- Model Loading Functions ---
def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a model from the MLflow model registry.

    The latest version registered at ``stage`` is resolved first and then
    loaded by its explicit version number. This guarantees the returned
    model and metadata describe the same registry version. Loading the
    ``models:/<name>/<stage>`` alias in a separate call could pick up a
    version promoted in between.

    Args:
        model_name: Name of the registered model.
        stage: Model stage (Production, Staging, None).

    Returns:
        Tuple of (model, model_info). ``model_info["trained_at"]`` is an
        ISO-8601 string in UTC.

    Raises:
        ValueError: If no version of the model exists at ``stage``.
        Exception: Any MLflow error (e.g. registry or tracking server unavailable).
    """
    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")
    try:
        client = mlflow.tracking.MlflowClient()
        model_versions = client.get_latest_versions(model_name, stages=[stage])
        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")
        model_version = model_versions[0]

        # Pin the exact version just resolved instead of the moving stage alias.
        model_uri = f"models:/{model_name}/{model_version.version}"
        model = mlflow.pyfunc.load_model(model_uri)

        run = client.get_run(model_version.run_id)
        params = run.data.params
        metrics = run.data.metrics

        model_info = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            # start_time is epoch milliseconds; render it in UTC rather than
            # in the server's local timezone.
            "trained_at": datetime.fromtimestamp(
                run.info.start_time / 1000, tz=timezone.utc
            ).isoformat(),
            "dataset_version": params.get("dataset_version", None),
            "feature_engineering_enabled": params.get("feature_engineering", "true") == "true",
            "labels": json.loads(params.get("labels", "[]")),
            "per_class_metrics": [],
        }

        # Per-class metrics are logged as precision_<label>, recall_<label>,
        # f1_<label> (metrics) and support_<label> (param).
        prefix = "precision_"
        for key, value in metrics.items():
            if not key.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove
            # occurrences inside the label itself.
            label = key[len(prefix):]
            model_info["per_class_metrics"].append({
                "label": label,
                "precision": value,
                "recall": metrics.get(f"recall_{label}", 0.0),
                "f1_score": metrics.get(f"f1_{label}", 0.0),
                "support": int(params.get(f"support_{label}", 0)),
            })

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info
    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise
def verify_model_checksum(model_path: Path, manifest_path: Optional[Path] = None) -> None:
    """
    Verify model file integrity against a SHA256 checksum manifest.

    The manifest defaults to ``models/checksums.sha256``, relative to the
    working directory; pass ``manifest_path`` to override it. It uses
    ``sha256sum`` output format: one ``<sha256hash> <filename>`` entry per
    line, with blank lines and ``#`` comments ignored.

    The model's entry is looked up in this order:

    1. An entry whose name equals the model file's basename.
    2. Otherwise, entries whose basename equals it (e.g. ``models/foo.pkl``,
       as written by ``sha256sum models/*.pkl``). These are used only if they
       all agree on the hash.

    The ``*`` binary-mode marker that ``sha256sum -b`` prefixes to names is
    stripped, and hashes are compared case-insensitively.

    Fail-open for backward compatibility: a missing or unreadable manifest,
    or no unambiguous entry for the file, logs a warning and returns.

    Args:
        model_path: Path to the model file to verify.
        manifest_path: Optional manifest location override.

    Raises:
        HTTPException: 500 if the file cannot be hashed or its hash does not match.
    """
    manifest_path = (
        Path(manifest_path) if manifest_path is not None else Path("models/checksums.sha256")
    )
    model_path = Path(model_path)
    if not manifest_path.exists():
        logger.warning("Model checksum manifest not found at %s — skipping integrity check", manifest_path)
        return

    # Parse manifest: each line is "<sha256hash> <filename>"
    checksums: Dict[str, str] = {}
    try:
        for line in manifest_path.read_text().splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split(None, 1)
            if len(parts) == 2:
                digest = parts[0].strip().lower()
                name = parts[1].strip().removeprefix("*")
                checksums[name] = digest
    except Exception as e:
        logger.warning("Failed to read checksum manifest: %s — skipping integrity check", e)
        return

    filename = model_path.name
    expected_hash = checksums.get(filename)
    if expected_hash is None:
        # Fall back to entries recorded with a directory prefix.
        basename_hashes = {
            digest for name, digest in checksums.items() if Path(name).name == filename
        }
        if len(basename_hashes) > 1:
            logger.warning(
                "Conflicting checksum entries for '%s' in manifest — skipping integrity check",
                filename,
            )
            return
        if basename_hashes:
            expected_hash = basename_hashes.pop()
    if expected_hash is None:
        logger.warning("No checksum entry for '%s' in manifest — skipping integrity check", filename)
        return

    # Compute SHA256 of the actual file
    sha256 = hashlib.sha256()
    try:
        with open(model_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
    except Exception as e:
        logger.error("Failed to compute SHA256 for '%s': %s", model_path, e)
        raise HTTPException(status_code=500, detail="Internal server error")

    actual_hash = sha256.hexdigest()
    if actual_hash != expected_hash:
        logger.error(
            "Model integrity check FAILED for '%s': expected %s, got %s",
            filename, expected_hash, actual_hash
        )
        raise HTTPException(status_code=500, detail="Internal server error")
    logger.info("Model integrity check passed for '%s'", filename)
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a model from a local file using joblib.

    The file may hold either a bare estimator or a dict of the form
    ``{"model": estimator, "metadata": {...}}``. The file is checked against
    the checksum manifest (verify_model_checksum) before it is unpickled.

    Args:
        model_path: Path to .pkl model file.

    Returns:
        Tuple of (model, model_info).

    Raises:
        FileNotFoundError: If the model file doesn't exist.
        HTTPException: If checksum verification fails.
        ValueError: If the file holds a dict without a usable "model" entry.
        Exception: If model loading fails.
    """
    model_path = Path(model_path)
    logger.info(f"Loading model from local file: {model_path}")
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # Verify model integrity before loading
    verify_model_checksum(model_path)

    try:
        # SECURITY: joblib.load unpickles and can execute arbitrary code. Only
        # load files from a trusted location; the checksum check above is
        # fail-open when the manifest or the file's entry is missing.
        model_data = joblib.load(model_path)

        if isinstance(model_data, dict):
            model = model_data.get("model")
            # Tolerate an explicit None metadata entry as well as a missing one.
            metadata = model_data.get("metadata") or {}
            if model is None:
                # Previously returned a None model as if loading had succeeded.
                raise ValueError(
                    f"Model file {model_path} contains a dict without a 'model' entry"
                )
        else:
            model = model_data
            metadata = {}

        # Fall back to the estimator's (or wrapped estimator's) classes_.
        labels = metadata.get("labels", [])
        if not labels:
            if hasattr(model, 'classes_'):
                labels = [str(c) for c in model.classes_]
            elif hasattr(model, 'model') and hasattr(model.model, 'classes_'):
                labels = [str(c) for c in model.model.classes_]

        model_info = {
            "model_name": model_path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": labels,
            "per_class_metrics": metadata.get("per_class_metrics", []),
            "feature_columns": metadata.get("feature_columns", None),
        }
        logger.info(f"Successfully loaded local model: {model_path.name}")
        return model, model_info
    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise
# --- Health Check ---
def _check_database_health() -> str:
    """Return "healthy" if ``SELECT 1`` succeeds on a database session, else "unhealthy"."""
    try:
        with get_db() as db:
            db.execute(sa_text("SELECT 1"))
        return "healthy"
    except Exception as exc:
        logger.warning(f"Health check: database unhealthy: {exc}")
        return "unhealthy"


def _check_mlflow_health() -> str:
    """GET ``<MLFLOW_TRACKING_URI>/health`` with a 3 s timeout; "healthy" only on HTTP 200."""
    # NOTE(review): this probes MLFLOW_TRACKING_URI from the environment, while
    # the lifespan hook points the MLflow client at the pipeline-config URI;
    # the two may disagree.
    mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
    mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health"
    try:
        resp = http_requests.get(mlflow_health_url, timeout=3)
    except Exception as exc:
        logger.warning(f"Health check: MLflow unreachable: {exc}")
        return "unhealthy"
    if resp.status_code == 200:
        return "healthy"
    logger.warning(
        f"Health check: MLflow returned HTTP {resp.status_code}"
    )
    return "unhealthy"


@app.get("/health", response_model=HealthResponse)
async def health_check(response: Response):
    """
    Health check endpoint.

    Reports whether a model is loaded and whether the database and MLflow
    respond. Overall status:
    - "healthy": both dependencies are up and a model is loaded.
    - "degraded": only one of those conditions holds.
    - "unhealthy": dependencies are down and no model is loaded.

    Returns HTTP 503 if any component (database or MLflow) is unhealthy.
    """
    import asyncio  # local import, matching the file's lazy-import style

    model_loaded = state.model is not None
    # Both probes block (DB round-trip; HTTP call with a 3 s timeout). Run
    # them concurrently in worker threads so the event loop stays responsive.
    db_status, mlflow_status = await asyncio.gather(
        asyncio.to_thread(_check_database_health),
        asyncio.to_thread(_check_mlflow_health),
    )
    deps_healthy = db_status == "healthy" and mlflow_status == "healthy"

    # A service without a model cannot predict, so it is never fully healthy.
    if deps_healthy and model_loaded:
        overall_status = "healthy"
    elif deps_healthy or model_loaded:
        overall_status = "degraded"
    else:
        overall_status = "unhealthy"

    # Return HTTP 503 if any dependency is unhealthy
    if not deps_healthy:
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE

    return HealthResponse(
        status=overall_status,
        model_loaded=model_loaded,
        mlflow=mlflow_status,
        database=db_status,
    )
# --- Model Info Endpoints ---
@app.get("/model/info", response_model=ModelInfoResponse, dependencies=[Depends(verify_api_key)])
async def get_model_info():
    """
    Get detailed model information and per-class metrics.

    The model and its metadata are read together under ``_model_swap_lock``,
    as /predict does, so a concurrent model swap cannot mix two models'
    metadata into one response.

    Raises:
        HTTPException: 503 if no model, or no model metadata, is loaded.
    """
    with _model_swap_lock:
        current_model = state.model
        info = state.model_info
    if current_model is None or info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    return ModelInfoResponse(
        model_name=info["model_name"],
        model_version=info.get("model_version"),
        model_type=info["model_type"],
        trained_at=info.get("trained_at"),
        dataset_version=info.get("dataset_version"),
        feature_engineering_enabled=info["feature_engineering_enabled"],
        labels=info.get("labels", []),
        per_class_metrics=[
            PerClassMetrics(**metric)
            for metric in info.get("per_class_metrics", [])
        ]
    )
@app.get("/model/labels", response_model=List[LabelInfo], dependencies=[Depends(verify_api_key)])
async def get_model_labels():
    """
    Get all pattern labels the model can predict, with display colors.

    Labels come from the loaded model's metadata (``model_info["labels"]``),
    read together with the model under ``_model_swap_lock``. Every label
    currently gets the placeholder color "#888888".

    Raises:
        HTTPException: 503 if no model is loaded.
    """
    with _model_swap_lock:
        current_model = state.model
        current_model_info = state.model_info
    if current_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # TODO: Load label colors from config or database
    labels = current_model_info.get("labels", []) if current_model_info else []
    return [
        LabelInfo(name=label, color="#888888")
        for label in labels
    ]
# --- Prediction Helper Functions ---
def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Merge runs of consecutive, identically labelled predictions into spans.

    The background label "O" never forms a span; it only closes whatever run
    is open. Each span records:
    - the times of its first and last member,
    - how many predictions it covers,
    - their mean confidence, maintained as a running average.

    Args:
        predictions: Per-window predictions, in reading order.

    Returns:
        Spans in input order; an empty list for empty input.
    """
    if not predictions:
        return []

    spans: List[PredictionSpan] = []
    open_span: Optional[PredictionSpan] = None
    for pred in predictions:
        # open_span never carries "O", so a label match means a real pattern run.
        if open_span is not None and open_span.label == pred.label:
            open_span.end_time = pred.time
            open_span.candle_count += 1
            previous_total = open_span.avg_confidence * (open_span.candle_count - 1)
            open_span.avg_confidence = (previous_total + pred.confidence) / open_span.candle_count
            continue

        # Label changed or background reached: close the open run first.
        if open_span is not None:
            spans.append(open_span)
            open_span = None

        if pred.label != "O":
            open_span = PredictionSpan(
                start_time=pred.time,
                end_time=pred.time,
                label=pred.label,
                avg_confidence=pred.confidence,
                candle_count=1,
            )

    if open_span is not None:
        spans.append(open_span)

    logger.info(f"Grouped {len(predictions)} predictions into {len(spans)} spans")
    return spans
# --- Prediction Endpoints ---
def _snapshot_loaded_model() -> tuple[Any, Optional[Dict[str, Any]]]:
    """
    Read ``state.model`` and ``state.model_info`` together under ``_model_swap_lock``.

    Inference then runs on these local references outside the lock. This
    prevents a request from seeing a partially swapped model.

    Returns:
        Tuple of (model, model_info); model_info may be None.

    Raises:
        HTTPException: 503 if no model is loaded at the time of the snapshot.
    """
    with _model_swap_lock:
        current_model = state.model
        current_model_info = state.model_info
    if current_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )
    return current_model, current_model_info


def _predict_window_chunk(model: Any, X: Any, window_times: List[Any]) -> List[PredictionResult]:
    """Run ``model`` on one block of feature rows and wrap each row as a PredictionResult."""
    y_pred = model.predict(X)
    if hasattr(model, 'predict_proba'):
        # Confidence is the highest class probability for each row.
        confidences = np.max(model.predict_proba(X), axis=1)
    else:
        # Fallback for models without predict_proba: fixed confidence of 1.0.
        confidences = np.ones(len(y_pred))

    # Map raw predictions via state.label_encoder when set, else str() them.
    if state.label_encoder is not None:
        labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
    else:
        labels = [str(pred) for pred in y_pred]

    # Each window maps to its last candle time.
    return [
        PredictionResult(
            time=int(time),
            label=label,
            confidence=float(conf)
        )
        for time, label, conf in zip(window_times, labels, confidences)
    ]


def _run_window_inference(
    model: Any,
    model_info: Optional[Dict[str, Any]],
    candles_data: List[Dict[str, Any]],
    batch_size: Optional[int] = None,
) -> List[PredictionResult]:
    """
    Preprocess chronologically ordered candles once and score every window.

    Preprocessing (feature engineering + windowing) always sees the whole
    candle series, so no window is lost at an artificial batch boundary.
    ``batch_size``, when positive, only caps how many windows go into each
    model call.

    Raises:
        HTTPException: 400 if preprocessing yields a different feature count
            than the model expects.
        ValueError: Propagated from preprocessing or the model.
    """
    training_feature_columns = (
        model_info.get("feature_columns") if model_info else None
    )
    X, window_times = preprocess_candles(
        candles_data,
        state.pipeline_config,
        training_feature_columns=training_feature_columns
    )

    expected_n_features = _get_model_n_features(model)
    if expected_n_features is not None and X.shape[1] != expected_n_features:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Feature mismatch: model expects {expected_n_features} features, "
                f"but preprocessing produced {X.shape[1]}. "
                "Ensure the loaded model matches the inference preprocessing config."
            ),
        )

    times = list(window_times)
    n_windows = X.shape[0]
    if not batch_size or batch_size <= 0 or n_windows <= batch_size:
        return _predict_window_chunk(model, X, times)

    predictions: List[PredictionResult] = []
    for start in range(0, n_windows, batch_size):
        stop = start + batch_size
        # Positional row slicing for both DataFrames and NumPy arrays.
        X_chunk = X.iloc[start:stop] if hasattr(X, "iloc") else X[start:stop]
        predictions.extend(_predict_window_chunk(model, X_chunk, times[start:stop]))
    return predictions


def _to_response_model_info(model_info: Optional[Dict[str, Any]]) -> ModelInfo:
    """Build the response ModelInfo from a loader's model_info dict (None raises TypeError)."""
    return ModelInfo(
        model_name=model_info["model_name"],
        model_version=model_info.get("model_version"),
        model_type=model_info["model_type"],
        trained_at=model_info.get("trained_at"),
        dataset_version=model_info.get("dataset_version"),
        feature_engineering_enabled=model_info["feature_engineering_enabled"]
    )


@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
async def predict(request: PredictRequest):
    """
    Predict candlestick patterns for provided candles.

    Candles are sorted by time (clients may send them in any order) and
    checked for duplicate timestamps. They are then preprocessed (feature
    engineering + windowing) and scored by a snapshot of the loaded model.

    Raises:
        HTTPException:
            - 503 if no model is loaded.
            - 500 if the pipeline config is missing or inference fails
              unexpectedly.
            - 400 for duplicate timestamps, a feature-count mismatch, or a
              ValueError raised during preprocessing/prediction.
    """
    # NOTE(review): preprocessing and model calls run synchronously in this
    # async handler and block the event loop while they run; consider a
    # worker thread once the models' thread-safety is confirmed.
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )
    if not request.candles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Candles array cannot be empty"
        )
    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")

    # Auto-sort candles by time so processing is chronological regardless
    # of client order; afterwards any duplicate timestamps are adjacent.
    candles_sorted = sorted(request.candles, key=lambda c: c.time)
    times = [c.time for c in candles_sorted]
    duplicate_times = [later for earlier, later in zip(times, times[1:]) if later == earlier]
    if duplicate_times:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Duplicate candle timestamps detected: {duplicate_times[:5]}"
        )

    current_model, current_model_info = _snapshot_loaded_model()

    try:
        candles_data = [candle.model_dump() for candle in candles_sorted]
        predictions = _run_window_inference(current_model, current_model_info, candles_data)
        spans = group_prediction_spans(predictions)
        model_info = _to_response_model_info(current_model_info)

        logger.info(
            f"Prediction complete: {len(predictions)} windows, "
            f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
        )
        return PredictResponse(
            predictions=predictions,
            spans=spans,
            model_info=model_info
        )
    except HTTPException:
        # Deliberate client errors (e.g. feature mismatch) keep their status;
        # previously the generic handler below turned them into 500s.
        raise
    except ValueError as e:
        logger.error(f"Prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        ) from e


@app.post("/predict/batch", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
async def predict_batch(request: BatchPredictRequest):
    """
    Batch prediction for a date range.

    Loads candles from the configured raw CSV and keeps those from
    ``start_date`` 00:00 UTC through the end of ``end_date`` (both dates
    inclusive). The rows are sorted by time and preprocessed as one series.
    Model calls are chunked by ``stages.inference.batch_size`` (a missing or
    non-positive value means one call).

    Raises:
        HTTPException:
            - 503 if no model is loaded.
            - 500 if the config is missing or inference fails.
            - 400 for bad dates, a range over 365 days, or a feature mismatch.
            - 404 if no data is found.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )
    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    # Validate date range
    try:
        start_dt = datetime.strptime(request.start_date, "%Y-%m-%d")
        end_dt = datetime.strptime(request.end_date, "%Y-%m-%d")
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid date format. Use YYYY-MM-DD",
        ) from exc
    if end_dt < start_dt:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="end_date must be after start_date",
        )
    if (end_dt - start_dt).days > 365:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Date range cannot exceed 1 year",
        )

    logger.info(
        f"Batch predict: {request.pair} {request.timeframe} "
        f"from {request.start_date} to {request.end_date}"
    )

    current_model, current_model_info = _snapshot_loaded_model()

    try:
        # NOTE(review): raw_path comes from the pipeline config and does not
        # depend on request.pair / request.timeframe.
        raw_path = Path(state.pipeline_config.data.raw_path)
        if not raw_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No data found for {request.pair} {request.timeframe}"
            )

        df_raw = pd.read_csv(raw_path)

        if 'time' in df_raw.columns:
            # Naive pandas Timestamps convert to epoch seconds as UTC. The end
            # bound is the midnight after end_date (exclusive), so end_date
            # itself is fully included; previously only its 00:00 candle was.
            start_ts = int(pd.Timestamp(request.start_date).timestamp())
            end_exclusive_ts = int(
                (pd.Timestamp(request.end_date) + pd.Timedelta(days=1)).timestamp()
            )
            df_filtered = df_raw[
                (df_raw['time'] >= start_ts) & (df_raw['time'] < end_exclusive_ts)
            ].sort_values('time', kind='mergesort')  # chronological, as /predict does

            if len(df_filtered) == 0:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"No data found in date range {request.start_date} to {request.end_date}"
                )
            logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
        else:
            df_filtered = df_raw
            logger.warning("No 'time' column found, using all data")

        batch_size = state.pipeline_config.stages.inference.batch_size
        candles_data = df_filtered.to_dict('records')
        logger.info(f"Running inference on {len(candles_data)} candles (batch_size={batch_size})")
        all_predictions = (
            _run_window_inference(
                current_model, current_model_info, candles_data, batch_size=batch_size
            )
            if candles_data
            else []
        )

        all_spans = group_prediction_spans(all_predictions)
        model_info = _to_response_model_info(current_model_info)

        logger.info(
            f"Batch prediction complete: {len(all_predictions)} candles, "
            f"{len(all_spans)} spans"
        )
        return PredictResponse(
            predictions=all_predictions,
            spans=all_spans,
            model_info=model_info
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        ) from e
# ---------------------------------------------------------------------------
# Pattern Detection Endpoints
# ---------------------------------------------------------------------------
class PatternInfo(BaseModel):
    """A single supported CDL pattern."""
    # CDL function name, the form DetectPatternsRequest.patterns expects.
    function_name: str
    # Human-readable name for display.
    display_name: str
class DetectPatternsRequest(BaseModel):
    """Request model for POST /patterns/detect."""
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
    # Checked by the handler with validate_pattern_names(); unknown names
    # produce HTTP 400.
    patterns: List[str] = Field(
        default=[],
        description="CDL function names to run. Empty list means run all.",
    )
class SpanAnnotation(BaseModel):
    """A span annotation returned by pattern detection."""
    # Constructed directly from the dicts detect_patterns() returns, so their
    # keys must match these field names.
    start_time: int
    end_time: int
    label: str
    confidence: float = Field(..., ge=0.0, le=1.0)
    source: str
    notes: str
class DetectPatternsResponse(BaseModel):
    """Response model for POST /patterns/detect."""
    annotations: List[SpanAnnotation]
    # The handler currently sends {"source": "talib", "count": len(annotations)}.
    metadata: Dict[str, Any]
@app.get("/patterns/available", response_model=List[PatternInfo], dependencies=[Depends(verify_api_key)])
async def patterns_available():
    """
    List every CDL pattern this service can detect.

    Each entry pairs a CDL function name with a display name; FastAPI
    validates the result against ``List[PatternInfo]``.
    """
    available = get_available_patterns()
    return available
@app.post("/patterns/detect", response_model=DetectPatternsResponse, dependencies=[Depends(verify_api_key)])
async def patterns_detect(request: DetectPatternsRequest):
    """
    Detect TA-Lib CDL patterns on provided candle data.

    - Candles are sorted by ``time`` first, as /predict does, so the pattern
      functions always see a chronological series.
    - Empty ``patterns`` list runs all available CDL functions.
    - Invalid pattern names return HTTP 400.
    - A RuntimeError raised by detection is reported as HTTP 503.
    """
    # Validate pattern names
    if request.patterns:
        invalid = validate_pattern_names(request.patterns)
        if invalid:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
                       f"Use GET /patterns/available to see supported patterns.",
            )

    # Pattern functions are order-sensitive; clients may send any order.
    candles_sorted = sorted(request.candles, key=lambda c: c.time)
    candles_data = [c.model_dump() for c in candles_sorted]
    try:
        raw_annotations = detect_patterns(candles_data, request.patterns or None)
    except RuntimeError as exc:
        logger.error(f"Pattern detection runtime error: {exc}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Internal server error",
        ) from exc

    annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
    return DetectPatternsResponse(
        annotations=annotations,
        metadata={"source": "talib", "count": len(annotations)},
    )
# ---------------------------------------------------------------------------
# Training Endpoints
# ---------------------------------------------------------------------------
# Model types the training endpoints support.
# NOTE(review): TrainingStartRequest.model_type is a free-form str at the
# schema level; presumably the training handler checks it against this
# list (handler not visible here; confirm).
SUPPORTED_MODEL_TYPES = ["random_forest", "xgboost"]
class TrainingStartRequest(BaseModel):
    """Request model for POST /training/start."""
    model_type: str = Field(
        "random_forest",
        description="Model type: random_forest or xgboost",
    )
    # ge=1 rejects zero/negative IDs during validation. None lets
    # build_dataset_from_db fall back to the first chart it finds.
    chart_id: Optional[int] = Field(
        default=None,
        ge=1,
        description="Chart ID to train on. If omitted, falls back to first chart.",
    )
class TrainingStartResponse(BaseModel):
    """Response model for POST /training/start."""
    # Identifier of the training run.
    run_id: str
    # Run status string (e.g. "running"; exact values set by the handler, not visible here).
    status: str
class TrainingRunInfo(BaseModel):
    """Summary of a single training run."""
    run_id: str
    model_type: str
    # Values used in this file include "running" and "failed" (set on
    # startup for runs interrupted by a restart).
    status: str
    experiment_name: Optional[str] = None
    # Timestamps as strings (presumably ISO-8601; confirm in the handler).
    created_at: Optional[str] = None
    completed_at: Optional[str] = None
    # Free-form metrics dict; interrupted runs carry {"error": ...}.
    metrics_summary: Optional[Dict[str, Any]] = None
class TrainingRunsResponse(BaseModel):
    """Response model for GET /training/runs."""
    # Order is decided by the handler (not visible here).
    runs: List[TrainingRunInfo]
class DatasetInfoResponse(BaseModel):
"""Response model for GET /training/dataset-info."""
path: str
exists: bool
size_bytes: Optional[int] = None
last_modified: Optional[str] = None
row_count: Optional[int] = None
def build_dataset_from_db(config: PipelineConfig, chart_id: Optional[int] = None) -> dict:
    """
    Build the labeled training dataset directly from the database.

    Steps:
        1. Export the chart's candles (time/open/high/low/close) to the raw
           CSV at ``config.data.raw_path``.
        2. Run feature engineering via ``run_feature_engineering_stage``; its
           output is read back from ``config.data.enriched_path``.
        3. Ingest span annotations from the database (all sources, so TA-Lib
           generated spans are used alongside manual labels) and write the
           labeled CSV to ``config.data.labeled_path``.

    Args:
        config: Pipeline configuration (data paths and stage settings).
        chart_id: Chart to build the dataset for. When ``None`` the first row
            returned by ``DataAccess.get_all_charts()`` is used.

    Returns:
        dict with keys ``chart_name``, ``n_candles``, ``n_samples``,
        ``n_features`` and ``labeled_path``. ``n_features`` counts only the
        columns the training job feeds to the model: it excludes ``label``,
        ``time``/``timestamp`` and ``label_programmatic_*`` columns.

    Raises:
        ValueError: if the requested chart does not exist, there are no
            charts, the chart has no candles, or annotation ingestion
            produces no labeled samples.
    """
    from app.data_access import DataAccess
    from app.annotation_ingestion import AnnotationIngestion
    from features.engineer import run_feature_engineering_stage
    data_access = DataAccess()
    # Resolve target chart
    if chart_id is not None:
        chart = data_access.get_chart_by_id(chart_id)
        if not chart:
            raise ValueError(f"Chart not found: id={chart_id}")
    else:
        charts_df = data_access.get_all_charts()
        if charts_df.empty:
            raise ValueError("No charts found in database. Upload candle data first.")
        chart = charts_df.iloc[0]
    chart_name = chart["name"]
    chart_id_int = int(chart["id"])
    logger.info(f"Building dataset for chart: {chart_name} (id={chart_id_int})")
    # Step 1: Export candles to raw CSV
    candles_df = data_access.get_candles(chart_id_int)
    if candles_df.empty:
        raise ValueError(f"No candles found for chart: {chart_name}")
    raw_path = Path(config.data.raw_path)
    raw_path.parent.mkdir(parents=True, exist_ok=True)
    # Export only the time + OHLC columns as input to feature engineering.
    export_df = candles_df[["time", "open", "high", "low", "close"]].copy()
    export_df.to_csv(raw_path, index=False)
    logger.info(f"Exported {len(export_df)} candles to {raw_path}")
    # Step 2: Run feature engineering
    run_feature_engineering_stage(config)
    enriched_path = Path(config.data.enriched_path)
    logger.info(f"Feature engineering complete: {enriched_path}")
    # Step 3: Run annotation ingestion from database
    enriched_df = pd.read_csv(enriched_path, parse_dates=["time"])
    ingestion = AnnotationIngestion(config.stages.annotation_ingestion)
    # Include all annotation sources so TA-Lib generated spans (source='talib')
    # can be used for training alongside manual labels.
    labeled_df = ingestion.process_from_db(enriched_df, chart_name, source=None)
    if labeled_df.empty:
        raise ValueError(
            f"No labeled samples produced. "
            f"Ensure you have span annotations on chart '{chart_name}'."
        )
    # Write labeled dataset
    labeled_path = Path(config.data.labeled_path)
    labeled_path.parent.mkdir(parents=True, exist_ok=True)
    labeled_df.to_csv(labeled_path, index=False)
    # Count features exactly as the training job selects them, so the
    # reported n_features matches what the model is trained on.
    non_feature_columns = {"label", "time", "timestamp"}
    n_features = sum(
        1
        for col in labeled_df.columns
        if col not in non_feature_columns and not col.startswith("label_programmatic_")
    )
    result = {
        "chart_name": chart_name,
        "n_candles": len(export_df),
        "n_samples": len(labeled_df),
        "n_features": n_features,
        "labeled_path": str(labeled_path),
    }
    logger.info(f"Dataset built: {result}")
    return result
def _mark_training_run_failed(run_id: str, error: str, reason: str = "failed") -> None:
    """
    Best-effort: set the TrainingRun row for ``run_id`` to status "failed".

    Stamps ``completed_at`` with the current UTC time and stores
    ``{"error": error}`` as ``metrics_summary``. Database errors are logged
    and swallowed so a secondary DB failure never masks the original error.

    Args:
        run_id: Identifier of the pre-inserted TrainingRun row.
        error: Message stored under ``metrics_summary["error"]``.
        reason: Word used in the log line if the DB update itself fails
            (e.g. "failed" or "timed-out").
    """
    try:
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.run_id == run_id)
                .values(
                    status="failed",
                    completed_at=datetime.now(timezone.utc),
                    metrics_summary={"error": error},
                )
            )
            db.execute(stmt)
            db.commit()
    except Exception as db_exc:
        logger.error(f"Failed to update DB for {reason} run {run_id}: {db_exc}")
def _run_training_background(
    run_id: str,
    model_type: str,
    config: PipelineConfig,
    chart_id: Optional[int] = None,
) -> None:
    """
    Background thread target: build the dataset, train, evaluate and save a model.

    Uses the TrainingRun row pre-inserted by ``training_start`` (identified
    by ``run_id``) and updates it to "completed" with metrics, or to "failed"
    with an error message. Exceptions are logged and recorded on that row
    rather than propagated. On exit, ``state.active_training_run_id`` is
    cleared if it still refers to this run.

    Args:
        run_id: Identifier of the TrainingRun row; also the artifact file
            name ``models/{run_id}.pkl``.
        model_type: Model type passed to ``training.train.create_model``.
        config: Pipeline configuration (data paths, training settings).
        chart_id: Chart to build the dataset from; ``None`` uses the first chart.
    """
    logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}")
    try:
        # Import training utilities here to avoid circular import issues
        from training.train import create_model, temporal_split
        from sklearn.metrics import accuracy_score, f1_score
        # Build dataset from database (feature engineering + annotation ingestion)
        logger.info("Building dataset from database...")
        build_dataset_from_db(config, chart_id=chart_id)
        labeled_path = Path(config.data.labeled_path)
        if not labeled_path.exists():
            raise FileNotFoundError(f"Labeled dataset not found: {labeled_path}")
        # Load dataset
        df = pd.read_csv(labeled_path)
        if "label" not in df.columns:
            raise ValueError("Labeled dataset must have 'label' column")
        # Check dataset size: reject if it exceeds 500MB in memory
        _DATASET_SIZE_LIMIT_BYTES = 500 * 1024 * 1024  # 500 MB
        dataset_size_bytes = df.memory_usage(deep=True).sum()
        if dataset_size_bytes > _DATASET_SIZE_LIMIT_BYTES:
            raise ValueError(
                f"Dataset too large. Maximum size is 500MB "
                f"(current size: {dataset_size_bytes / (1024 * 1024):.1f}MB)."
            )
        logger.info(
            f"Dataset size check passed: {dataset_size_bytes / (1024 * 1024):.1f}MB"
        )
        # Everything except the target, timestamps and programmatic label
        # columns is fed to the model.
        feature_cols = [
            col for col in df.columns
            if col not in ("label", "time", "timestamp")
            and not col.startswith("label_programmatic_")
        ]
        X = df[feature_cols].values
        y = df["label"].values
        logger.info(f"Loaded {len(X)} samples, {len(feature_cols)} features")
        # Split data
        training_cfg = config.stages.training
        X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
            X, y, training_cfg.test_split, training_cfg.validation_split
        )
        # Run model training with a 30-minute timeout
        _TRAINING_TIMEOUT_SECONDS = 1800  # 30 minutes
        def _do_train():
            _model = create_model(
                model_type, training_cfg.hyperparameters, training_cfg.class_weights
            )
            _model.fit(X_train, y_train)
            return _model
        # The executor is deliberately NOT used as a context manager: its
        # __exit__ calls shutdown(wait=True), which would block here until
        # fit() returns and so defeat the timeout. The run would stay
        # "active" and 409 new runs long after being marked failed.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(_do_train)
            try:
                model_instance = future.result(timeout=_TRAINING_TIMEOUT_SECONDS)
            except concurrent.futures.TimeoutError:
                logger.error(
                    f"Training timed out after 30 minutes: run_id={run_id}"
                )
                _mark_training_run_failed(
                    run_id, "Training timed out after 30 minutes", reason="timed-out"
                )
                return
        finally:
            # A running fit() cannot be interrupted: after a timeout it keeps
            # running in its worker thread and its result is discarded.
            # NOTE(review): that orphaned fit still consumes CPU, and Python
            # joins executor workers at interpreter exit.
            executor.shutdown(wait=False)
        logger.info("Model training complete")
        # Evaluate
        y_val_pred = model_instance.predict(X_val)
        y_test_pred = model_instance.predict(X_test)
        metrics = {
            "val_accuracy": float(accuracy_score(y_val, y_val_pred)),
            "val_f1_macro": float(
                f1_score(y_val, y_val_pred, average="macro", zero_division=0)
            ),
            "test_accuracy": float(accuracy_score(y_test, y_test_pred)),
            "test_f1_macro": float(
                f1_score(y_test, y_test_pred, average="macro", zero_division=0)
            ),
            "n_samples": int(len(X)),
            "n_features": int(X.shape[1]),
        }
        # Save model locally
        models_dir = Path("models")
        models_dir.mkdir(exist_ok=True)
        model_path = models_dir / f"{run_id}.pkl"
        model_data = {
            "model": model_instance,
            "metadata": {
                "model_type": model_type,
                "trained_at": datetime.now(timezone.utc).isoformat(),
                "run_id": run_id,
                "feature_columns": feature_cols,
                "labels": (
                    [str(c) for c in model_instance.model.classes_]
                    if hasattr(model_instance, "model") and hasattr(model_instance.model, "classes_")
                    else []
                ),
            },
        }
        joblib.dump(model_data, model_path)
        logger.info(f"Model saved to {model_path}")
        # Update DB: completed
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.run_id == run_id)
                .values(
                    status="completed",
                    completed_at=datetime.now(timezone.utc),
                    metrics_summary=metrics,
                )
            )
            db.execute(stmt)
            db.commit()
        logger.info(f"Training completed successfully: run_id={run_id}")
    except Exception as exc:
        logger.error(f"Training failed for run_id={run_id}: {exc}", exc_info=True)
        _mark_training_run_failed(run_id, str(exc))
    finally:
        # Release the single training slot, but only if it is still ours.
        with state.training_lock:
            if state.active_training_run_id == run_id:
                state.active_training_run_id = None
        logger.info(f"Training thread exiting: run_id={run_id}")
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
async def training_start(request: TrainingStartRequest):
    """
    Start a training run in a background thread.

    Validates the request, atomically claims the single training slot
    (``state.active_training_run_id``), pre-inserts a TrainingRun row with
    status "running", then launches ``_run_training_background`` in a
    daemon thread and returns immediately.

    Returns:
        TrainingStartResponse with the new ``run_id`` and status "running".

    Raises:
        HTTPException 400: unsupported ``model_type`` or unknown ``chart_id``.
        HTTPException 409: another run is already active (detail carries its run_id).
        HTTPException 500: the TrainingRun row could not be inserted.
    """
    # Validate model type
    if request.model_type not in SUPPORTED_MODEL_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
        )
    if request.chart_id is not None:
        from app.data_access import DataAccess
        chart = DataAccess().get_chart_by_id(request.chart_id)
        if not chart:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Chart not found: id={request.chart_id}",
            )
    # Resolve config BEFORE claiming the training slot. If this raised after
    # the claim, nothing would release the slot and every later start would
    # be rejected with 409.
    config = state.pipeline_config or get_default_config()
    # Compute config hash (best-effort: a failure must not block training)
    config_hash = "unknown"
    try:
        from training.train import compute_config_hash
        config_hash = compute_config_hash(config)
    except Exception as exc:
        logger.warning(f"Could not compute pipeline config hash, using 'unknown': {exc}")
    # Reject concurrent runs (atomic check-and-set)
    with state.training_lock:
        if state.active_training_run_id is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail={
                    "error": "Training already in progress",
                    "run_id": state.active_training_run_id,
                },
            )
        run_id = str(uuid.uuid4())
        state.active_training_run_id = run_id
    # Pre-insert the run record so callers can track it immediately
    try:
        with get_db() as db:
            training_run = TrainingRun(
                run_id=run_id,
                model_type=request.model_type,
                experiment_name=config.stages.training.mlflow.experiment_name,
                pipeline_config_hash=config_hash,
                status="running",
                created_at=datetime.now(timezone.utc),
                metrics_summary={},
            )
            db.add(training_run)
            db.commit()
    except Exception as exc:
        # Release the slot claimed above so later requests are not blocked.
        with state.training_lock:
            state.active_training_run_id = None
        logger.error(f"Failed to insert training run record: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from exc
    # Launch background thread (daemon so it doesn't block process exit)
    thread = threading.Thread(
        target=_run_training_background,
        args=(run_id, request.model_type, config, request.chart_id),
        daemon=True,
        name=f"training-{run_id[:8]}",
    )
    thread.start()
    logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}")
    return TrainingStartResponse(run_id=run_id, status="running")
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
async def training_runs():
    """
    List every recorded training run, newest first.

    Rows are ordered by ``created_at`` descending and converted to
    TrainingRunInfo, with timestamps rendered via ``isoformat()``.

    Raises:
        HTTPException 500: the database query failed (details are logged).
    """
    try:
        from sqlalchemy import select
        with get_db() as db:
            newest_first = select(TrainingRun).order_by(desc(TrainingRun.created_at))
            records = db.execute(newest_first).scalars().all()
            # Convert rows while still inside the session context.
            runs = []
            for record in records:
                created = record.created_at
                finished = record.completed_at
                runs.append(
                    TrainingRunInfo(
                        run_id=record.run_id,
                        model_type=record.model_type,
                        status=record.status,
                        experiment_name=record.experiment_name,
                        created_at=created.isoformat() if created else None,
                        completed_at=finished.isoformat() if finished else None,
                        metrics_summary=record.metrics_summary,
                    )
                )
    except Exception as exc:
        logger.error(f"Failed to fetch training runs: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
    return TrainingRunsResponse(runs=runs)
class ActiveTrainingResponse(BaseModel):
    """Response body for GET /training/active: whether a run is in flight, and which."""
    active: bool
    run_id: Optional[str] = None
@app.get("/training/active", response_model=ActiveTrainingResponse, dependencies=[Depends(verify_api_key)])
async def training_active():
    """
    Report the currently active training run, if any.

    Reads ``state.active_training_run_id`` under ``state.training_lock``;
    ``active`` is True exactly when that id is set.
    """
    with state.training_lock:
        current_run = state.active_training_run_id
    is_active = current_run is not None
    return ActiveTrainingResponse(active=is_active, run_id=current_run)
class DeleteRunResponse(BaseModel):
    """Response model for DELETE /training/runs/{run_id}."""
    run_id: str
    deleted: bool
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
async def delete_training_run(run_id: str):
    """
    Delete a training run record and its model artifact.

    The run_id and artifact path are validated before anything is deleted.
    The database row is removed first. The ``models/{run_id}.pkl`` artifact
    is then removed best-effort; a failure there is only logged.

    Returns:
        DeleteRunResponse with ``deleted=True``.

    Raises:
        HTTPException 400: the run_id format is invalid.
        HTTPException 409: the run is currently active.
        HTTPException 404: the run_id doesn't exist.
        HTTPException 500: the database delete failed.
    """
    from sqlalchemy import select, delete as sa_delete
    # Validate run_id format to prevent path traversal. re.fullmatch is used
    # because re.match(r'^...$') also accepts a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    # Check the artifact path BEFORE mutating anything, so an invalid path
    # can never leave the DB row deleted while returning an error.
    # - Comparing Path objects (not string prefixes with a hard-coded "/")
    #   keeps the check correct on Windows.
    # - The artifact itself is not resolved, so a symlinked artifact is
    #   unlinked as a link.
    models_base = Path("models").resolve()
    model_path = models_base / f"{run_id}.pkl"
    if model_path.parent != models_base:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    # Reject deletion of the active run
    with state.training_lock:
        if state.active_training_run_id == run_id:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Cannot delete an active training run",
            )
    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
            row = db.execute(stmt).scalar_one_or_none()
            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )
            db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
            db.commit()
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to delete training run {run_id}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
    # Remove model artifact if it exists (best-effort)
    if model_path.exists():
        try:
            model_path.unlink()
            logger.info(f"Deleted model artifact: {model_path}")
        except Exception as exc:
            logger.warning(f"Could not delete model artifact {model_path}: {exc}")
    logger.info(f"Deleted training run: {run_id}")
    return DeleteRunResponse(run_id=run_id, deleted=True)
@app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)])
async def get_training_run(run_id: str):
    """
    Get information about a specific training run by run_id.

    Returns:
        TrainingRunInfo with timestamps rendered via ``isoformat()``.

    Raises:
        HTTPException 400: the run_id format is invalid.
        HTTPException 404: the run_id doesn't exist.
        HTTPException 500: the database query failed (details are logged).
    """
    from sqlalchemy import select
    # Reject malformed run_ids early, using the same format as the other
    # run_id endpoints. re.fullmatch is used because re.match(r'^...$')
    # also accepts a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
            row = db.execute(stmt).scalar_one_or_none()
            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )
            return TrainingRunInfo(
                run_id=row.run_id,
                model_type=row.model_type,
                status=row.status,
                experiment_name=row.experiment_name,
                created_at=row.created_at.isoformat() if row.created_at else None,
                completed_at=row.completed_at.isoformat() if row.completed_at else None,
                metrics_summary=row.metrics_summary,
            )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to fetch training run {run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
async def training_dataset_info():
    """
    Return information about the labeled training dataset.

    Reports the configured ``labeled_path`` and whether it exists. When it
    exists, also reports:
    - its size in bytes;
    - its last-modified time as a timezone-aware UTC ISO-8601 string;
    - its row count, which is ``None`` if the CSV cannot be parsed.

    Raises:
        HTTPException 500: reading the file's metadata failed (details are logged).
    """
    config = state.pipeline_config or get_default_config()
    labeled_path = Path(config.data.labeled_path)
    if not labeled_path.exists():
        return DatasetInfoResponse(path=str(labeled_path), exists=False)
    try:
        stat = labeled_path.stat()
        size_bytes = stat.st_size
        # Aware UTC, consistent with the other timestamps this service emits.
        # A bare fromtimestamp() would give ambiguous naive local time.
        last_modified = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat()
        row_count = None
        try:
            # Parse a single column only, to keep the count cheap.
            df_head = pd.read_csv(labeled_path, usecols=[0])
            row_count = len(df_head)
        except Exception as exc:
            # Best-effort: still report the other fields without a row count.
            logger.warning(f"Could not count rows in {labeled_path}: {exc}")
        return DatasetInfoResponse(
            path=str(labeled_path),
            exists=True,
            size_bytes=size_bytes,
            last_modified=last_modified,
            row_count=row_count,
        )
    except Exception as exc:
        logger.error(f"Failed to get dataset info: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
class BuildDatasetResponse(BaseModel):
    """Response model for POST /training/build-dataset."""
    chart_name: str
    n_candles: int
    n_samples: int
    n_features: int
    labeled_path: str
@app.post("/training/build-dataset", response_model=BuildDatasetResponse, dependencies=[Depends(verify_api_key)])
async def training_build_dataset():
    """
    Build the labeled training dataset from database annotations.

    Exports candles, runs feature engineering, and ingests span annotations
    into a labeled CSV ready for training (see ``build_dataset_from_db``;
    the first chart is used).

    Raises:
        HTTPException 400: ``build_dataset_from_db`` raised ValueError. Its
            message (e.g. no charts, no candles, no annotations) is returned
            as the detail so the caller can act on it.
        HTTPException 500: any other failure (details are logged).
    """
    config = state.pipeline_config or get_default_config()
    try:
        result = build_dataset_from_db(config)
    except ValueError as exc:
        # These messages are written for the user; returning them is far
        # more useful than a contradictory 400 "Internal server error".
        logger.warning(f"Dataset build rejected: {exc}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.error(f"Failed to build dataset: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
    # Build the response outside the try. A pydantic ValidationError is a
    # ValueError subclass and must not be misreported as a client error.
    return BuildDatasetResponse(**result)
# ---------------------------------------------------------------------------
# Model Loading Endpoint
# ---------------------------------------------------------------------------
class ModelLoadRequest(BaseModel):
    """Request model for POST /model/load: the completed training run to activate."""
    run_id: str = Field(..., description="Training run ID to load model from")
class ModelLoadResponse(BaseModel):
    """Response model for POST /model/load."""
    run_id: str
    model_type: str
    # "loaded" on success.
    status: str
# Serializes model_load's swap of state.model / state.model_info /
# state.feature_columns so the three are replaced together.
# NOTE(review): readers only see a consistent triple if the prediction
# handlers also take this lock; that is not visible in this chunk, so confirm.
_model_swap_lock = threading.Lock()
@app.post("/model/load", response_model=ModelLoadResponse, dependencies=[Depends(verify_api_key)])
async def model_load(request: ModelLoadRequest):
    """
    Load a trained model from a completed training run and make it active.

    Steps:
    1. Look up ``run_id`` in the training_runs table.
    2. Load ``models/{run_id}.pkl`` via ``load_model_from_local``. This
       happens outside any lock, since deserialization can be slow.
    3. Swap ``state.model``, ``state.model_info`` and
       ``state.feature_columns`` while holding ``_model_swap_lock``.

    Raises:
        HTTPException 400: malformed run_id, or the run's status is not "completed".
        HTTPException 404: unknown run_id, or its model artifact is missing.
        HTTPException 500: the database lookup or model loading failed.
    """
    from sqlalchemy import select
    # 0. Validate run_id format to prevent path traversal. re.fullmatch is
    #    used because re.match(r'^...$') also accepts a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', request.run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    # 1. Look up the training run
    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == request.run_id)
            row = db.execute(stmt).scalar_one_or_none()
    except Exception as exc:
        logger.error(f"DB lookup failed for run_id={request.run_id}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Training run not found: {request.run_id}",
        )
    if row.status != "completed":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Training run is not completed (status={row.status})",
        )
    # 2. Resolve the model artifact path. The resolved file must sit directly
    #    inside models/, so a symlink pointing elsewhere is refused. Comparing
    #    Path objects (not string prefixes with a hard-coded "/") keeps this
    #    correct on Windows.
    models_base = Path("models").resolve()
    model_path = (models_base / f"{request.run_id}.pkl").resolve()
    if model_path.parent != models_base:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    if not model_path.exists():
        # Log the server-side path; don't echo filesystem layout to clients.
        logger.warning(
            f"Model artifact missing for run_id={request.run_id}: {model_path}"
        )
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model artifact not found for run_id: {request.run_id}",
        )
    # 3. Load model (outside the lock: can be slow)
    try:
        new_model, new_model_info = load_model_from_local(str(model_path))
    except Exception as exc:
        logger.error(f"Failed to load model for run_id={request.run_id}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
    # 4. Thread-safe model swap (3.2 brief lock)
    with _model_swap_lock:
        state.model = new_model
        state.model_info = new_model_info
        state.feature_columns = new_model_info.get("feature_columns")
    logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")
    return ModelLoadResponse(
        run_id=request.run_id,
        model_type=row.model_type,
        status="loaded",
    )
if __name__ == "__main__":
    # Direct-execution entry point: serve this app with uvicorn on all
    # interfaces, port 8001.
    from uvicorn import run as _serve
    _serve(app, port=8001, host="0.0.0.0")