Replace all instances of `detail=str(e)`, `detail=f"...{exc}"`, and similar
patterns that exposed internal exception messages to HTTP clients in
services/ml/app/main.py. Exception details are now logged server-side only
(via logger.error()), and clients receive a generic "Internal server error"
message instead. This fixes 9 handlers across the predict, batch predict,
pattern detection, training start, training runs fetch, training run delete,
dataset info, build dataset, and model load endpoints.
Mark task 5.1 as completed in tasks.md.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1502 lines
50 KiB
Python
"""
|
||
FastAPI inference service for candlestick pattern prediction.
|
||
|
||
Provides REST API endpoints for model serving, health checks, and prediction.
|
||
"""
|
||
|
||
import hmac
import json
import logging
import os
import re
import threading
import uuid as uuid_lib
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List

from fastapi import FastAPI, HTTPException, Header, Depends, Security, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import numpy as np
import pandas as pd
import joblib
import mlflow
import mlflow.sklearn
import mlflow.xgboost
from sqlalchemy import update as sa_update, desc

from app.config import load_config, PipelineConfig, get_default_config
from app.db import get_db, TrainingRun, init_db
from app.preprocessing import preprocess_candles, extract_feature_columns
from app.patterns import (
    TALIB_PATTERNS,
    get_available_patterns,
    validate_pattern_names,
    detect_patterns,
)
|
||
|
||
# Configure root logging once at import time: INFO level, timestamped lines
# tagged with the emitting logger's name.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every handler in this service.
logger = logging.getLogger(__name__)
|
||
|
||
# --- API Key Dependency ---
|
||
|
||
async def verify_api_key(x_api_key: str = Header(default="")):
    """
    Verify the ``X-API-Key`` header against the ``API_KEY`` environment variable.

    Authentication is fail-open: when ``API_KEY`` is unset or empty, every
    request is accepted, so deployments must set it to enforce auth.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header ("" when absent).

    Raises:
        HTTPException: 401 when ``API_KEY`` is set and the header does not match.
    """
    api_key = os.getenv("API_KEY")
    if not api_key:
        return  # fail-open if not configured
    # Constant-time comparison: a plain ``!=`` can short-circuit on the first
    # differing character, letting response timing leak how much of a guessed
    # key was correct. Encoding to bytes lets non-ASCII keys compare safely.
    if not hmac.compare_digest(x_api_key.encode("utf-8"), api_key.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||
|
||
|
||
# FastAPI app
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0"
)

# Parse CORS origins from the comma-separated CORS_ORIGINS environment
# variable, defaulting to the local frontend dev server. Blank entries (from a
# trailing comma or "a, ,b") are dropped instead of being registered as an
# empty-string origin.
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
allow_origins = [
    origin.strip() for origin in cors_origins_str.split(",") if origin.strip()
]

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# Global state
class AppState:
    """Process-wide, mutable container for the serving model and training bookkeeping."""

    # Model object and its metadata dict; assigned during startup when a model loads.
    model: Optional[Any] = None
    model_info: Optional[Dict[str, Any]] = None
    # Parsed pipeline configuration (None until startup loads one).
    pipeline_config: Optional[PipelineConfig] = None
    feature_columns: Optional[List[str]] = None
    # Optional mapping from integer class id to label name; when None the
    # prediction handlers stringify raw model outputs instead.
    label_encoder: Optional[Dict[int, str]] = None

    # Training thread management
    active_training_run_id: Optional[str] = None
    # Class-level placeholder only; every instance gets its own lock below.
    training_lock: threading.Lock = None

    def __init__(self):
        lock = threading.Lock()
        self.training_lock = lock


state = AppState()
|
||
|
||
|
||
# --- Pydantic Models ---
|
||
|
||
class CandleData(BaseModel):
    """One OHLC(V) bar submitted by a client."""

    # Bar timestamp.
    time: int = Field(description="Unix timestamp in seconds")
    # Price levels of the bar.
    open: float
    high: float
    low: float
    close: float
    # Volume is optional; callers may omit it.
    volume: Optional[float] = Field(default=None)
|
||
|
||
|
||
class PredictRequest(BaseModel):
    """Payload accepted by the /predict endpoint."""

    # Optional context describing where the candles came from.
    pair: Optional[str] = Field(default=None, description="Trading pair (e.g., EURUSD)")
    timeframe: Optional[str] = Field(default=None, description="Timeframe (e.g., 1H, 4H, 1D)")
    # Required, and must contain at least one candle.
    candles: List[CandleData] = Field(description="Array of candle data", min_length=1)
|
||
|
||
|
||
class PredictionResult(BaseModel):
    """Model output for one window, keyed by that window's candle time."""

    time: int
    label: str
    # Score constrained to the closed interval [0, 1].
    confidence: float = Field(ge=0.0, le=1.0)
|
||
|
||
|
||
class PredictionSpan(BaseModel):
    """A run of consecutive predictions that share one label."""

    # First and last candle times covered by the run.
    start_time: int
    end_time: int
    label: str
    # Mean confidence over the run, constrained to [0, 1].
    avg_confidence: float = Field(ge=0.0, le=1.0)
    # Number of predictions merged into this span.
    candle_count: int
|
||
|
||
|
||
class ModelInfo(BaseModel):
    """Metadata describing the model that produced a prediction response."""

    model_name: str
    model_version: Optional[str] = Field(default=None)
    model_type: str
    trained_at: Optional[str] = Field(default=None)
    dataset_version: Optional[str] = Field(default=None)
    feature_engineering_enabled: bool
|
||
|
||
|
||
class PredictResponse(BaseModel):
    """Body returned by /predict and /predict/batch."""

    # Per-window predictions, in input order.
    predictions: List[PredictionResult]
    # Consecutive non-background predictions grouped into spans.
    spans: List[PredictionSpan]
    # Which model produced the predictions.
    model_info: ModelInfo
|
||
|
||
|
||
class BatchPredictRequest(BaseModel):
    """Payload accepted by the /predict/batch endpoint."""

    pair: str
    timeframe: str
    # Inclusive date bounds, given as ISO strings.
    start_date: str = Field(description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(description="ISO format date (YYYY-MM-DD)")
|
||
|
||
|
||
class LabelInfo(BaseModel):
    """A label the model can emit, plus the colour used to display it."""

    # Label name as produced by the model.
    name: str
    # Display colour string (e.g. "#888888").
    color: str
|
||
|
||
|
||
class PerClassMetrics(BaseModel):
    """Evaluation metrics for one class label."""

    label: str
    # Standard classification scores for this label.
    precision: float
    recall: float
    f1_score: float
    # Number of samples with this label in the evaluation set.
    support: int
|
||
|
||
|
||
class ModelInfoResponse(BaseModel):
    """Model metadata extended with its label set and per-class metrics."""

    model_name: str
    model_version: Optional[str] = Field(default=None)
    model_type: str
    trained_at: Optional[str] = Field(default=None)
    dataset_version: Optional[str] = Field(default=None)
    feature_engineering_enabled: bool
    # Every label the model can predict.
    labels: List[str]
    # One entry per label that has recorded metrics.
    per_class_metrics: List[PerClassMetrics]
|
||
|
||
|
||
class HealthResponse(BaseModel):
    """Body returned by the /health endpoint."""

    status: str = Field(description="healthy, degraded, or unhealthy")
    model_loaded: bool
    # Dependency states.
    mlflow: str = Field(description="connected or disconnected")
    database: str = Field(description="connected or disconnected")
|
||
|
||
|
||
# --- Model Loading Functions ---
|
||
|
||
def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a model and its metadata from the MLflow model registry.

    The registry version at ``stage`` is resolved first, and the artifact is
    then loaded by that exact version number. This guarantees the returned model
    and ``model_info`` describe the same version even if the stage is
    reassigned concurrently. It also avoids downloading anything when no version
    exists at the requested stage.

    Args:
        model_name: Name of the registered model.
        stage: Registry stage to resolve (Production, Staging, None).

    Returns:
        Tuple of (model, model_info). ``model_info`` has the keys model_name,
        model_version, model_type, trained_at, dataset_version,
        feature_engineering_enabled, labels and per_class_metrics.

    Raises:
        ValueError: If no model version exists at ``stage``.
        Exception: Any MLflow/client error is logged and re-raised.
    """
    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")

    try:
        client = mlflow.tracking.MlflowClient()

        # Resolve the concrete version behind the stage before downloading.
        model_versions = client.get_latest_versions(model_name, stages=[stage])
        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")
        model_version = model_versions[0]

        # Load by version (not the moving stage alias) so artifact and metadata agree.
        model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_version.version}")

        # Run metadata: params are stored as strings, metrics as floats.
        run = client.get_run(model_version.run_id)
        params = run.data.params
        metrics = run.data.metrics

        model_info = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            # start_time is in milliseconds since the epoch.
            "trained_at": datetime.fromtimestamp(run.info.start_time / 1000).isoformat(),
            "dataset_version": params.get("dataset_version", None),
            "feature_engineering_enabled": params.get("feature_engineering", "true") == "true",
            "labels": json.loads(params.get("labels", "[]")),
            "per_class_metrics": []
        }

        # Per-class metrics are logged as precision_<label>, recall_<label>,
        # f1_<label> metrics plus a support_<label> param.
        prefix = "precision_"
        for key, value in metrics.items():
            if not key.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also rewrite a
            # label that itself contains "precision_".
            label = key[len(prefix):]
            model_info["per_class_metrics"].append({
                "label": label,
                "precision": value,
                "recall": metrics.get(f"recall_{label}", 0.0),
                "f1_score": metrics.get(f"f1_{label}", 0.0),
                "support": int(params.get(f"support_{label}", 0))
            })

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise
|
||
|
||
|
||
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a model (and optional metadata) from a local joblib file.

    The file may contain either the bare estimator or a dict of the form
    ``{"model": <estimator>, "metadata": {...}}``. When the metadata lists no
    labels, they are taken from the estimator's ``classes_`` (or from
    ``model.model.classes_`` for a wrapped estimator).

    Args:
        model_path: Path to the ``.pkl`` model file.

    Returns:
        Tuple of (model, model_info). ``model_info`` has the keys model_name,
        model_version, model_type, trained_at, dataset_version,
        feature_engineering_enabled, labels and per_class_metrics.

    Raises:
        FileNotFoundError: If the model file doesn't exist.
        ValueError: If the file is a dict without a usable ``"model"`` entry.
        Exception: Any other loading error is logged and re-raised.
    """
    model_path = Path(model_path)
    logger.info(f"Loading model from local file: {model_path}")

    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    try:
        # joblib.load unpickles the file, which can execute arbitrary code:
        # only point this at model files from a trusted location.
        model_data = joblib.load(model_path)

        # Extract model and metadata
        if isinstance(model_data, dict):
            model = model_data.get("model")
            # Treat an explicit "metadata": None the same as missing metadata.
            metadata = model_data.get("metadata") or {}
        else:
            model = model_data
            metadata = {}

        # Without this check a dict lacking "model" would be reported as a
        # successful load while leaving the service with no model.
        if model is None:
            raise ValueError(f"Model file contains no model object: {model_path}")

        # Extract labels from model if not in metadata
        labels = metadata.get("labels", [])
        if not labels:
            if hasattr(model, 'classes_'):
                labels = [str(c) for c in model.classes_]
            elif hasattr(model, 'model') and hasattr(model.model, 'classes_'):
                labels = [str(c) for c in model.model.classes_]

        model_info = {
            "model_name": model_path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": labels,
            "per_class_metrics": metadata.get("per_class_metrics", [])
        }

        logger.info(f"Successfully loaded local model: {model_path.name}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise
|
||
|
||
|
||
# --- Startup Event ---
|
||
|
||
@app.on_event("startup")
|
||
async def startup_event():
|
||
"""
|
||
Load model and pipeline config on startup.
|
||
"""
|
||
logger.info("Starting inference service...")
|
||
|
||
# Ensure training_runs table exists
|
||
init_db()
|
||
|
||
# Mark any stale "running" records as failed — they belong to a previous
|
||
# process and will never complete.
|
||
try:
|
||
with get_db() as db:
|
||
stmt = (
|
||
sa_update(TrainingRun)
|
||
.where(TrainingRun.status == "running")
|
||
.values(
|
||
status="failed",
|
||
completed_at=datetime.utcnow(),
|
||
metrics_summary={"error": "Service restarted while training was in progress"},
|
||
)
|
||
)
|
||
result = db.execute(stmt)
|
||
db.commit()
|
||
if result.rowcount:
|
||
logger.warning(
|
||
f"Marked {result.rowcount} stale training run(s) as failed on startup"
|
||
)
|
||
except Exception as exc:
|
||
logger.error(f"Failed to clean up stale training runs: {exc}")
|
||
|
||
# Load pipeline config
|
||
config_path = Path("config/pipeline.yaml")
|
||
if not config_path.exists():
|
||
logger.warning(f"Config file not found: {config_path}. Using defaults.")
|
||
return
|
||
|
||
try:
|
||
state.pipeline_config = load_config(config_path)
|
||
logger.info(f"Loaded pipeline config from {config_path}")
|
||
except Exception as e:
|
||
logger.error(f"Failed to load pipeline config: {e}")
|
||
return
|
||
|
||
# Load model based on config
|
||
inference_config = state.pipeline_config.stages.inference
|
||
|
||
if not inference_config.enabled:
|
||
logger.warning("Inference stage is disabled in config")
|
||
return
|
||
|
||
# Load model
|
||
try:
|
||
if inference_config.model_source == "mlflow":
|
||
# Configure MLflow tracking URI
|
||
mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
|
||
|
||
state.model, state.model_info = load_model_from_mlflow(
|
||
model_name=inference_config.mlflow_model_name,
|
||
stage=inference_config.mlflow_model_stage
|
||
)
|
||
elif inference_config.model_source == "local":
|
||
state.model, state.model_info = load_model_from_local(
|
||
model_path=inference_config.local_model_path
|
||
)
|
||
else:
|
||
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
||
return
|
||
|
||
logger.info("Model loaded successfully")
|
||
logger.info(f"Model info: {state.model_info['model_name']} "
|
||
f"v{state.model_info['model_version']} "
|
||
f"({state.model_info['model_type']})")
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to load model: {e}")
|
||
logger.warning("Service will start without a model. Use /health to check status.")
|
||
|
||
|
||
# --- Health Check ---
|
||
|
||
@app.get("/health", response_model=HealthResponse)
|
||
async def health_check():
|
||
"""
|
||
Health check endpoint.
|
||
|
||
Returns service status, model loaded status, and dependency health.
|
||
"""
|
||
model_loaded = state.model is not None
|
||
|
||
# Check MLflow connection
|
||
mlflow_status = "disconnected"
|
||
try:
|
||
# TODO: Actually check MLflow connection
|
||
mlflow_status = "connected"
|
||
except Exception:
|
||
pass
|
||
|
||
# Check database connection
|
||
db_status = "disconnected"
|
||
try:
|
||
# TODO: Actually check database connection
|
||
db_status = "connected"
|
||
except Exception:
|
||
pass
|
||
|
||
# Determine overall status
|
||
if model_loaded and mlflow_status == "connected" and db_status == "connected":
|
||
overall_status = "healthy"
|
||
elif model_loaded:
|
||
overall_status = "degraded"
|
||
else:
|
||
overall_status = "unhealthy"
|
||
|
||
return HealthResponse(
|
||
status=overall_status,
|
||
model_loaded=model_loaded,
|
||
mlflow=mlflow_status,
|
||
database=db_status
|
||
)
|
||
|
||
|
||
# --- Model Info Endpoints (Stubs) ---
|
||
|
||
@app.get("/model/info", response_model=ModelInfoResponse, dependencies=[Depends(verify_api_key)])
|
||
async def get_model_info():
|
||
"""
|
||
Get detailed model information and per-class metrics.
|
||
|
||
Returns HTTP 503 if no model is loaded.
|
||
"""
|
||
if state.model is None or state.model_info is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||
detail="No model available"
|
||
)
|
||
|
||
# Build response from loaded model info
|
||
return ModelInfoResponse(
|
||
model_name=state.model_info["model_name"],
|
||
model_version=state.model_info.get("model_version"),
|
||
model_type=state.model_info["model_type"],
|
||
trained_at=state.model_info.get("trained_at"),
|
||
dataset_version=state.model_info.get("dataset_version"),
|
||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"],
|
||
labels=state.model_info.get("labels", []),
|
||
per_class_metrics=[
|
||
PerClassMetrics(**metric)
|
||
for metric in state.model_info.get("per_class_metrics", [])
|
||
]
|
||
)
|
||
|
||
|
||
@app.get("/model/labels", response_model=List[LabelInfo], dependencies=[Depends(verify_api_key)])
|
||
async def get_model_labels():
|
||
"""
|
||
Get all pattern labels the model can predict with display colors.
|
||
"""
|
||
if state.model is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||
detail="No model available"
|
||
)
|
||
|
||
# TODO: Load label colors from config or database
|
||
# For now, return placeholder
|
||
labels = state.model_info.get("labels", []) if state.model_info else []
|
||
|
||
return [
|
||
LabelInfo(name=label, color="#888888")
|
||
for label in labels
|
||
]
|
||
|
||
|
||
# --- Prediction Helper Functions ---
|
||
|
||
def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Collapse runs of consecutive, identically labelled predictions into spans.

    Background ("O") predictions never produce spans; they only separate
    neighbouring runs. A span's ``avg_confidence`` is the mean confidence of its
    predictions, accumulated incrementally in input order.

    Args:
        predictions: Per-window predictions, in time order.

    Returns:
        Spans in the order their runs appear in ``predictions``.
    """
    if not predictions:
        return []

    # Pass 1: split the sequence into maximal runs that share a label.
    runs: List[List[PredictionResult]] = []
    previous_label = None
    for pred in predictions:
        if not runs or pred.label != previous_label:
            runs.append([])
            previous_label = pred.label
        runs[-1].append(pred)

    # Pass 2: turn every non-background run into a span.
    spans: List[PredictionSpan] = []
    for run in runs:
        first = run[0]
        if first.label == "O":
            continue
        mean_conf = first.confidence
        for seen, pred in enumerate(run[1:], start=2):
            # Incremental mean: fold the seen-th value into the mean of the previous seen-1.
            mean_conf = (mean_conf * (seen - 1) + pred.confidence) / seen
        spans.append(
            PredictionSpan(
                start_time=first.time,
                end_time=run[-1].time,
                label=first.label,
                avg_confidence=mean_conf,
                candle_count=len(run),
            )
        )

    logger.info(f"Grouped {len(predictions)} predictions into {len(spans)} spans")
    return spans
|
||
|
||
|
||
# --- Prediction Endpoints ---
|
||
|
||
@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
|
||
async def predict(request: PredictRequest):
|
||
"""
|
||
Predict candlestick patterns for provided candles.
|
||
|
||
Accepts OHLCV candles, runs preprocessing, and returns predictions.
|
||
"""
|
||
if state.model is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||
detail="No model available"
|
||
)
|
||
|
||
if not request.candles:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="Candles array cannot be empty"
|
||
)
|
||
|
||
if state.pipeline_config is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="Pipeline configuration not loaded"
|
||
)
|
||
|
||
logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")
|
||
|
||
try:
|
||
# Convert candles to list of dicts
|
||
candles_data = [candle.model_dump() for candle in request.candles]
|
||
|
||
# Preprocess candles (feature engineering + windowing)
|
||
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
||
|
||
# Get predictions and probabilities
|
||
if hasattr(state.model, 'predict_proba'):
|
||
y_pred = state.model.predict(X)
|
||
y_proba = state.model.predict_proba(X)
|
||
|
||
# Get confidence as max probability
|
||
confidences = np.max(y_proba, axis=1)
|
||
else:
|
||
# Fallback for models without predict_proba
|
||
y_pred = state.model.predict(X)
|
||
confidences = np.ones(len(y_pred)) # Default confidence of 1.0
|
||
|
||
# Get label names (handle both string and int predictions)
|
||
if state.label_encoder is not None:
|
||
labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
|
||
else:
|
||
labels = [str(pred) for pred in y_pred]
|
||
|
||
# Build per-window predictions (each window maps to its last candle time)
|
||
predictions = [
|
||
PredictionResult(
|
||
time=int(time),
|
||
label=label,
|
||
confidence=float(conf)
|
||
)
|
||
for time, label, conf in zip(window_times, labels, confidences)
|
||
]
|
||
|
||
# Group into spans
|
||
spans = group_prediction_spans(predictions)
|
||
|
||
# Build model info for response
|
||
model_info = ModelInfo(
|
||
model_name=state.model_info["model_name"],
|
||
model_version=state.model_info.get("model_version"),
|
||
model_type=state.model_info["model_type"],
|
||
trained_at=state.model_info.get("trained_at"),
|
||
dataset_version=state.model_info.get("dataset_version"),
|
||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
||
)
|
||
|
||
logger.info(
|
||
f"Prediction complete: {len(predictions)} windows, "
|
||
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
|
||
)
|
||
|
||
return PredictResponse(
|
||
predictions=predictions,
|
||
spans=spans,
|
||
model_info=model_info
|
||
)
|
||
|
||
except ValueError as e:
|
||
logger.error(f"Prediction validation error: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="Internal server error"
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Prediction failed: {e}", exc_info=True)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="Internal server error"
|
||
)
|
||
|
||
|
||
@app.post("/predict/batch", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
|
||
async def predict_batch(request: BatchPredictRequest):
|
||
"""
|
||
Batch prediction for a date range.
|
||
|
||
Loads data from storage and returns predictions for the full range.
|
||
"""
|
||
if state.model is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||
detail="No model available"
|
||
)
|
||
|
||
if state.pipeline_config is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="Pipeline configuration not loaded"
|
||
)
|
||
|
||
logger.info(
|
||
f"Batch predict: {request.pair} {request.timeframe} "
|
||
f"from {request.start_date} to {request.end_date}"
|
||
)
|
||
|
||
try:
|
||
# Load OHLCV data from raw data path
|
||
raw_path = Path(state.pipeline_config.data.raw_path)
|
||
|
||
if not raw_path.exists():
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail=f"No data found for {request.pair} {request.timeframe}"
|
||
)
|
||
|
||
# Load data
|
||
df_raw = pd.read_csv(raw_path)
|
||
|
||
# Filter by date range if time column exists
|
||
if 'time' in df_raw.columns:
|
||
# Parse dates to timestamps
|
||
start_ts = int(pd.Timestamp(request.start_date).timestamp())
|
||
end_ts = int(pd.Timestamp(request.end_date).timestamp())
|
||
|
||
# Filter data
|
||
df_filtered = df_raw[
|
||
(df_raw['time'] >= start_ts) & (df_raw['time'] <= end_ts)
|
||
].copy()
|
||
|
||
if len(df_filtered) == 0:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail=f"No data found in date range {request.start_date} to {request.end_date}"
|
||
)
|
||
|
||
logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
|
||
else:
|
||
df_filtered = df_raw
|
||
logger.warning("No 'time' column found, using all data")
|
||
|
||
# Get batch size from config
|
||
batch_size = state.pipeline_config.stages.inference.batch_size
|
||
|
||
# Process in batches
|
||
all_predictions = []
|
||
all_spans = []
|
||
|
||
for i in range(0, len(df_filtered), batch_size):
|
||
batch_df = df_filtered.iloc[i:i + batch_size]
|
||
|
||
logger.info(f"Processing batch {i // batch_size + 1}: rows {i} to {i + len(batch_df)}")
|
||
|
||
# Convert batch to candles format
|
||
batch_candles = batch_df.to_dict('records')
|
||
|
||
# Preprocess (feature engineering + windowing)
|
||
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
||
|
||
# Predict
|
||
if hasattr(state.model, 'predict_proba'):
|
||
y_pred = state.model.predict(X)
|
||
y_proba = state.model.predict_proba(X)
|
||
confidences = np.max(y_proba, axis=1)
|
||
else:
|
||
y_pred = state.model.predict(X)
|
||
confidences = np.ones(len(y_pred))
|
||
|
||
# Get labels
|
||
if state.label_encoder is not None:
|
||
labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
|
||
else:
|
||
labels = [str(pred) for pred in y_pred]
|
||
|
||
# Build predictions for this batch
|
||
batch_predictions = [
|
||
PredictionResult(
|
||
time=int(time),
|
||
label=label,
|
||
confidence=float(conf)
|
||
)
|
||
for time, label, conf in zip(window_times, labels, confidences)
|
||
]
|
||
|
||
all_predictions.extend(batch_predictions)
|
||
|
||
# Group all predictions into spans
|
||
all_spans = group_prediction_spans(all_predictions)
|
||
|
||
# Build model info
|
||
model_info = ModelInfo(
|
||
model_name=state.model_info["model_name"],
|
||
model_version=state.model_info.get("model_version"),
|
||
model_type=state.model_info["model_type"],
|
||
trained_at=state.model_info.get("trained_at"),
|
||
dataset_version=state.model_info.get("dataset_version"),
|
||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
||
)
|
||
|
||
logger.info(
|
||
f"Batch prediction complete: {len(all_predictions)} candles, "
|
||
f"{len(all_spans)} spans"
|
||
)
|
||
|
||
return PredictResponse(
|
||
predictions=all_predictions,
|
||
spans=all_spans,
|
||
model_info=model_info
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Batch prediction failed: {e}", exc_info=True)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="Internal server error"
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pattern Detection Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class PatternInfo(BaseModel):
    """One CDL pattern supported by the detector."""

    # TA-Lib function name, e.g. as accepted by /patterns/detect.
    function_name: str
    # Human-readable name for display.
    display_name: str
|
||
|
||
|
||
class DetectPatternsRequest(BaseModel):
    """Payload accepted by POST /patterns/detect."""

    # Required, and must contain at least one candle.
    candles: List[CandleData] = Field(description="Array of candle data", min_length=1)
    # Subset of CDL functions to run; leaving it empty selects all of them.
    patterns: List[str] = Field(
        description="CDL function names to run. Empty list means run all.",
        default=[],
    )
|
||
|
||
|
||
class SpanAnnotation(BaseModel):
    """A labelled time span produced by pattern detection."""

    start_time: int
    end_time: int
    label: str
    # Score constrained to [0, 1].
    confidence: float = Field(ge=0.0, le=1.0)
    # Origin of the annotation and free-form remarks.
    source: str
    notes: str
|
||
|
||
|
||
class DetectPatternsResponse(BaseModel):
    """Body returned by POST /patterns/detect."""

    # Detected spans.
    annotations: List[SpanAnnotation]
    # Extra information about the detection run (e.g. source and count).
    metadata: Dict[str, Any]
|
||
|
||
|
||
@app.get("/patterns/available", response_model=List[PatternInfo], dependencies=[Depends(verify_api_key)])
|
||
async def patterns_available():
|
||
"""
|
||
Return all supported CDL pattern names with display names.
|
||
"""
|
||
return get_available_patterns()
|
||
|
||
|
||
@app.post("/patterns/detect", response_model=DetectPatternsResponse, dependencies=[Depends(verify_api_key)])
|
||
async def patterns_detect(request: DetectPatternsRequest):
|
||
"""
|
||
Detect TA-Lib CDL patterns on provided candle data.
|
||
|
||
- Empty ``patterns`` list runs all available CDL functions.
|
||
- Invalid pattern names return HTTP 400.
|
||
"""
|
||
# 1.4 – Validate pattern names
|
||
if request.patterns:
|
||
invalid = validate_pattern_names(request.patterns)
|
||
if invalid:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
|
||
f"Use GET /patterns/available to see supported patterns.",
|
||
)
|
||
|
||
candles_data = [c.model_dump() for c in request.candles]
|
||
|
||
try:
|
||
raw_annotations = detect_patterns(candles_data, request.patterns or None)
|
||
except RuntimeError as exc:
|
||
logger.error(f"Pattern detection runtime error: {exc}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||
detail="Internal server error",
|
||
)
|
||
|
||
annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
|
||
|
||
return DetectPatternsResponse(
|
||
annotations=annotations,
|
||
metadata={"source": "talib", "count": len(annotations)},
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Training Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
SUPPORTED_MODEL_TYPES = ["random_forest", "xgboost"]
|
||
|
||
|
||
class TrainingStartRequest(BaseModel):
    """Payload accepted by POST /training/start."""

    # Which estimator to train; defaults to a random forest.
    model_type: str = Field(
        default="random_forest",
        description="Model type: random_forest or xgboost",
    )
|
||
|
||
|
||
class TrainingStartResponse(BaseModel):
    """Body returned by POST /training/start."""

    # Identifier of the newly created training run and its current status.
    run_id: str
    status: str
|
||
|
||
|
||
class TrainingRunInfo(BaseModel):
    """Summary of one training run."""

    run_id: str
    model_type: str
    status: str
    experiment_name: Optional[str] = Field(default=None)
    # Timestamps serialised as strings (format presumably ISO — confirm with producer).
    created_at: Optional[str] = Field(default=None)
    completed_at: Optional[str] = Field(default=None)
    # Free-form metrics dict; failed runs may carry an "error" entry instead.
    metrics_summary: Optional[Dict[str, Any]] = Field(default=None)
|
||
|
||
|
||
class TrainingRunsResponse(BaseModel):
    """Body returned by GET /training/runs."""

    # One entry per recorded training run.
    runs: List[TrainingRunInfo]
|
||
|
||
|
||
class DatasetInfoResponse(BaseModel):
    """Body returned by GET /training/dataset-info."""

    # Dataset location and whether a file exists there.
    path: str
    exists: bool
    # File details; left as None when unavailable (e.g. the file is missing).
    size_bytes: Optional[int] = Field(default=None)
    last_modified: Optional[str] = Field(default=None)
    row_count: Optional[int] = Field(default=None)
|
||
|
||
|
||
def build_dataset_from_db(config: PipelineConfig) -> dict:
    """
    Build the labeled training dataset directly from the database.

    Steps:
        1. Export the first chart's candles (time/open/high/low/close) from the
           database to ``config.data.raw_path`` as CSV.
        2. Run feature engineering; its output is read back from
           ``config.data.enriched_path``.
        3. Join the chart's span annotations (source="human") onto the
           enriched candles and write the result to ``config.data.labeled_path``.

    Args:
        config: Pipeline configuration providing data paths and stage settings.

    Returns:
        dict with keys ``chart_name``, ``n_candles``, ``n_samples``,
        ``n_features`` and ``labeled_path``.

    Raises:
        ValueError: If there are no charts, the chart has no candles, or the
            annotation join produces no labeled samples.
    """
    from app.data_access import DataAccess
    from app.annotation_ingestion import AnnotationIngestion
    from features.engineer import run_feature_engineering_stage

    data_access = DataAccess()

    # Find all charts, use the first one (single-chart app)
    charts_df = data_access.get_all_charts()
    if charts_df.empty:
        raise ValueError("No charts found in database. Upload candle data first.")
    if len(charts_df) > 1:
        # Make the single-chart assumption visible instead of silently ignoring data.
        logger.warning(f"{len(charts_df)} charts found; building dataset from the first one only")

    chart = charts_df.iloc[0]
    chart_name = chart["name"]
    chart_id = int(chart["id"])
    logger.info(f"Building dataset for chart: {chart_name} (id={chart_id})")

    # Step 1: Export candles to raw CSV
    candles_df = data_access.get_candles(chart_id)
    if candles_df.empty:
        raise ValueError(f"No candles found for chart: {chart_name}")

    raw_path = Path(config.data.raw_path)
    raw_path.parent.mkdir(parents=True, exist_ok=True)

    export_df = candles_df[["time", "open", "high", "low", "close"]].copy()
    export_df.to_csv(raw_path, index=False)
    logger.info(f"Exported {len(export_df)} candles to {raw_path}")

    # Step 2: Run feature engineering
    run_feature_engineering_stage(config)
    enriched_path = Path(config.data.enriched_path)
    logger.info(f"Feature engineering complete: {enriched_path}")

    # Step 3: Run annotation ingestion from database
    enriched_df = pd.read_csv(enriched_path, parse_dates=["time"])
    ingestion = AnnotationIngestion(config.stages.annotation_ingestion)
    labeled_df = ingestion.process_from_db(enriched_df, chart_name, source="human")

    if labeled_df.empty:
        raise ValueError(
            f"No labeled samples produced. "
            f"Ensure you have span annotations on chart '{chart_name}'."
        )

    # Write labeled dataset
    labeled_path = Path(config.data.labeled_path)
    labeled_path.parent.mkdir(parents=True, exist_ok=True)
    labeled_df.to_csv(labeled_path, index=False)

    # Count features with the same exclusion rule the training thread applies
    # when selecting model inputs: label/time columns and programmatic labels
    # are not features.
    n_features = sum(
        1
        for col in labeled_df.columns
        if col not in ("label", "time", "timestamp")
        and not str(col).startswith("label_programmatic_")
    )

    result = {
        "chart_name": chart_name,
        "n_candles": len(export_df),
        "n_samples": len(labeled_df),
        "n_features": n_features,
        "labeled_path": str(labeled_path),
    }
    logger.info(f"Dataset built: {result}")
    return result
|
||
|
||
|
||
def _run_training_background(run_id: str, model_type: str, config: PipelineConfig) -> None:
    """
    Background thread target: build the dataset, train a model, record the outcome.

    Uses the TrainingRun record pre-inserted by ``training_start`` and
    identified by ``run_id``; this function only updates that row.

    Args:
        run_id: Identifier of the pre-inserted TrainingRun row. Also used as the
            model artifact file name (``models/<run_id>.pkl``).
        model_type: Model type passed to ``training.train.create_model``.
        config: Pipeline configuration supplying dataset paths and the training
            stage settings (splits, hyperparameters, class weights).

    This function never raises. On success the row is marked ``completed`` with
    evaluation metrics. On any failure the full error and traceback are logged
    and the row is marked ``failed``. Because ``metrics_summary`` is returned
    as-is by ``GET /training/runs``, the failure record holds only a generic
    message and the exception class name, never the exception text. The
    ``finally`` clause releases ``state.active_training_run_id`` if it still
    refers to this run.
    """
    logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}")

    try:
        # Import training utilities here to avoid circular import issues
        from training.train import create_model, temporal_split
        from sklearn.metrics import accuracy_score, f1_score

        # Build dataset from database (feature engineering + annotation ingestion)
        logger.info("Building dataset from database...")
        build_dataset_from_db(config)

        labeled_path = Path(config.data.labeled_path)
        if not labeled_path.exists():
            raise FileNotFoundError(f"Labeled dataset not found: {labeled_path}")

        # Load dataset
        df = pd.read_csv(labeled_path)
        if "label" not in df.columns:
            raise ValueError("Labeled dataset must have 'label' column")

        # Everything except the target, time columns and programmatic label
        # columns is treated as a model feature.
        feature_cols = [
            col for col in df.columns
            if col not in ("label", "time", "timestamp")
            and not col.startswith("label_programmatic_")
        ]
        X = df[feature_cols].values
        y = df["label"].values
        logger.info(f"Loaded {len(X)} samples, {len(feature_cols)} features")

        # Split data
        training_cfg = config.stages.training
        X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
            X, y, training_cfg.test_split, training_cfg.validation_split
        )

        # Train model
        model_instance = create_model(
            model_type, training_cfg.hyperparameters, training_cfg.class_weights
        )
        model_instance.fit(X_train, y_train)
        logger.info("Model training complete")

        # Evaluate
        y_val_pred = model_instance.predict(X_val)
        y_test_pred = model_instance.predict(X_test)
        metrics = {
            "val_accuracy": float(accuracy_score(y_val, y_val_pred)),
            "val_f1_macro": float(
                f1_score(y_val, y_val_pred, average="macro", zero_division=0)
            ),
            "test_accuracy": float(accuracy_score(y_test, y_test_pred)),
            "test_f1_macro": float(
                f1_score(y_test, y_test_pred, average="macro", zero_division=0)
            ),
            "n_samples": int(len(X)),
            "n_features": int(X.shape[1]),
        }

        # Save model locally; /model/load and DELETE /training/runs/{run_id}
        # both look for models/<run_id>.pkl.
        models_dir = Path("models")
        models_dir.mkdir(exist_ok=True)
        model_path = models_dir / f"{run_id}.pkl"

        model_data = {
            "model": model_instance,
            "metadata": {
                "model_type": model_type,
                "trained_at": datetime.utcnow().isoformat(),
                "run_id": run_id,
                "feature_columns": feature_cols,
                "labels": (
                    [str(c) for c in model_instance.model.classes_]
                    if hasattr(model_instance, "model") and hasattr(model_instance.model, "classes_")
                    else []
                ),
            },
        }
        joblib.dump(model_data, model_path)
        logger.info(f"Model saved to {model_path}")

        # Update DB: completed
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.run_id == run_id)
                .values(
                    status="completed",
                    completed_at=datetime.utcnow(),
                    metrics_summary=metrics,
                )
            )
            db.execute(stmt)
            db.commit()

        logger.info(f"Training completed successfully: run_id={run_id}")

    except Exception as exc:
        logger.error(f"Training failed for run_id={run_id}: {exc}", exc_info=True)
        # metrics_summary is served verbatim to API clients by GET /training/runs,
        # so never persist the raw exception text (it can contain server paths
        # or library internals). The full message/traceback is in the log above.
        failure_summary = {
            "error": "Training failed; see server logs for details",
            "error_type": type(exc).__name__,
        }
        try:
            with get_db() as db:
                stmt = (
                    sa_update(TrainingRun)
                    .where(TrainingRun.run_id == run_id)
                    .values(
                        status="failed",
                        completed_at=datetime.utcnow(),
                        metrics_summary=failure_summary,
                    )
                )
                db.execute(stmt)
                db.commit()
        except Exception as db_exc:
            logger.error(
                f"Failed to update DB for failed run {run_id}: {db_exc}", exc_info=True
            )
    finally:
        # Release the single training slot, but only if it is still ours.
        with state.training_lock:
            if state.active_training_run_id == run_id:
                state.active_training_run_id = None
        logger.info(f"Training thread exiting: run_id={run_id}")
|
||
|
||
|
||
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
async def training_start(request: TrainingStartRequest):
    """
    Start a training run in a background thread.

    Returns immediately with run_id and status "running".
    Rejects concurrent runs with HTTP 409.

    Raises:
        HTTPException(400): ``request.model_type`` is not in SUPPORTED_MODEL_TYPES.
        HTTPException(409): another run is already active; the detail carries
            that run's id.
        HTTPException(500): the run record could not be inserted, or the worker
            thread could not be started. The training slot is released in both
            cases so later requests are not blocked; details are logged only.
    """
    # Validate model type
    if request.model_type not in SUPPORTED_MODEL_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
        )

    # Reject concurrent runs (atomic check-and-set)
    with state.training_lock:
        if state.active_training_run_id is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail={
                    "error": "Training already in progress",
                    "run_id": state.active_training_run_id,
                },
            )
        run_id = str(uuid_lib.uuid4())
        state.active_training_run_id = run_id

    def _release_training_slot() -> None:
        # Clear the slot only if it is still held by this request's run.
        with state.training_lock:
            if state.active_training_run_id == run_id:
                state.active_training_run_id = None

    # Everything between claiming the slot and starting the thread must
    # release the slot on failure, otherwise every later start returns 409.
    try:
        config = state.pipeline_config or get_default_config()

        # Compute config hash (best-effort: a failure only loses provenance info)
        config_hash = "unknown"
        try:
            from training.train import compute_config_hash
            config_hash = compute_config_hash(config)
        except Exception as hash_exc:
            logger.warning(f"Could not compute config hash for run_id={run_id}: {hash_exc}")

        # Pre-insert the run record so callers can track it immediately
        with get_db() as db:
            training_run = TrainingRun(
                run_id=run_id,
                model_type=request.model_type,
                experiment_name=config.stages.training.mlflow.experiment_name,
                pipeline_config_hash=config_hash,
                status="running",
                created_at=datetime.utcnow(),
                metrics_summary={},
            )
            db.add(training_run)
            db.commit()
    except Exception as exc:
        _release_training_slot()
        logger.error(f"Failed to insert training run record: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # Launch background thread (daemon so it doesn't block process exit)
    thread = threading.Thread(
        target=_run_training_background,
        args=(run_id, request.model_type, config),
        daemon=True,
        name=f"training-{run_id[:8]}",
    )
    try:
        thread.start()
    except RuntimeError as exc:
        # Thread.start raises RuntimeError when no new thread can be created.
        # Free the slot and mark the pre-inserted row failed so it does not
        # stay "running" forever.
        logger.error(f"Failed to start training thread for run_id={run_id}: {exc}", exc_info=True)
        _release_training_slot()
        try:
            with get_db() as db:
                db.execute(
                    sa_update(TrainingRun)
                    .where(TrainingRun.run_id == run_id)
                    .values(
                        status="failed",
                        completed_at=datetime.utcnow(),
                        metrics_summary={"error": "Training thread could not be started"},
                    )
                )
                db.commit()
        except Exception as db_exc:
            logger.error(f"Failed to mark run {run_id} as failed: {db_exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}")
    return TrainingStartResponse(run_id=run_id, status="running")
|
||
|
||
|
||
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
async def training_runs():
    """
    List every recorded training run, newest first.

    Returns:
        TrainingRunsResponse wrapping one TrainingRunInfo per row of the
        TrainingRun table, ordered by ``created_at`` descending. Timestamps
        are ISO-8601 strings, or None when the column is empty.

    Raises:
        HTTPException(500): the database query failed; the cause is logged
        server-side and only a generic message reaches the client.
    """
    def _iso(value):
        # completed_at (and possibly created_at) may be NULL for a run.
        return value.isoformat() if value else None

    try:
        from sqlalchemy import select

        with get_db() as db:
            query = select(TrainingRun).order_by(desc(TrainingRun.created_at))
            records = db.execute(query).scalars().all()

            runs = []
            for record in records:
                runs.append(
                    TrainingRunInfo(
                        run_id=record.run_id,
                        model_type=record.model_type,
                        status=record.status,
                        experiment_name=record.experiment_name,
                        created_at=_iso(record.created_at),
                        completed_at=_iso(record.completed_at),
                        metrics_summary=record.metrics_summary,
                    )
                )
    except Exception as exc:
        logger.error(f"Failed to fetch training runs: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    return TrainingRunsResponse(runs=runs)
|
||
|
||
|
||
class ActiveTrainingResponse(BaseModel):
    """
    Response model for GET /training/active.

    Reports whether a background training run currently occupies the single
    training slot (only one run may be active at a time).
    """

    # True while a training run holds the slot.
    active: bool
    # ID of the run holding the slot; None when ``active`` is False.
    run_id: Optional[str] = None
|
||
|
||
|
||
@app.get("/training/active", response_model=ActiveTrainingResponse, dependencies=[Depends(verify_api_key)])
async def training_active():
    """
    Report the currently running training job, if any.

    Reads the shared training slot under ``state.training_lock`` so the value
    is consistent with the check-and-set done by ``training_start``.
    """
    with state.training_lock:
        current = state.active_training_run_id

    is_running = current is not None
    return ActiveTrainingResponse(active=is_running, run_id=current)
|
||
|
||
|
||
class DeleteRunResponse(BaseModel):
    """
    Response model for DELETE /training/runs/{run_id}.

    Only returned on success; every failure is reported as an HTTP error.
    """

    # The run_id taken from the request path.
    run_id: str
    # Always True in a successful response.
    deleted: bool
|
||
|
||
|
||
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
async def delete_training_run(run_id: str):
    """
    Delete a training run record and its model artifact.

    Returns HTTP 400 if the run_id format is invalid.
    Returns HTTP 409 if the run is currently active.
    Returns HTTP 404 if the run_id doesn't exist.
    Returns HTTP 500 (generic detail; cause logged server-side) if the
    database operation fails.

    The artifact path is validated before the database row is touched, so a
    rejected request never deletes the row. Failure to remove the artifact
    file after the row is gone is logged but does not fail the request.
    """
    from sqlalchemy import select, delete as sa_delete

    # Validate run_id format to prevent path traversal. ``re.fullmatch`` is
    # used because ``re.match(r'^...$')`` also accepts a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    # Defense in depth: the artifact must resolve to a direct child of models/.
    # Comparing Path objects (rather than a string prefix with a hard-coded "/")
    # also works with Windows path separators.
    models_base = Path("models").resolve()
    model_path = (Path("models") / f"{run_id}.pkl").resolve()
    if model_path.parent != models_base:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    # Reject deletion of the active run
    with state.training_lock:
        if state.active_training_run_id == run_id:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Cannot delete an active training run",
            )

    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
            row = db.execute(stmt).scalar_one_or_none()

            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )

            db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
            db.commit()
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to delete training run {run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # Remove model artifact if it exists (best-effort: the row is already gone)
    if model_path.exists():
        try:
            model_path.unlink()
            logger.info(f"Deleted model artifact: {model_path}")
        except Exception as exc:
            logger.warning(f"Could not delete model artifact {model_path}: {exc}")

    logger.info(f"Deleted training run: {run_id}")
    return DeleteRunResponse(run_id=run_id, deleted=True)
|
||
|
||
|
||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
async def training_dataset_info():
    """
    Return information about the labeled training dataset.

    Includes file path, existence, size, last modified date, and row count.

    ``last_modified`` comes from ``datetime.fromtimestamp`` and is therefore in
    the server's local time zone. ``row_count`` is best-effort: if the CSV
    cannot be read it is returned as None (and the reason is logged) rather
    than failing the request.

    Raises:
        HTTPException(500): the file exists but could not be stat'ed or the
        response could not be built; the cause is logged server-side only.
    """
    config = state.pipeline_config or get_default_config()
    labeled_path = Path(config.data.labeled_path)

    if not labeled_path.exists():
        return DatasetInfoResponse(path=str(labeled_path), exists=False)

    try:
        stat = labeled_path.stat()
        size_bytes = stat.st_size
        last_modified = datetime.fromtimestamp(stat.st_mtime).isoformat()

        row_count = None
        try:
            # Read only one column for efficiency
            df_head = pd.read_csv(labeled_path, usecols=[0])
            row_count = len(df_head)
        except Exception as exc:
            # Optional field: an empty or unreadable CSV degrades to None
            # instead of a 500, but leave a trace for debugging.
            logger.warning(f"Could not count rows in labeled dataset {labeled_path}: {exc}")

        return DatasetInfoResponse(
            path=str(labeled_path),
            exists=True,
            size_bytes=size_bytes,
            last_modified=last_modified,
            row_count=row_count,
        )
    except Exception as exc:
        logger.error(f"Failed to get dataset info: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
|
||
|
||
|
||
class BuildDatasetResponse(BaseModel):
    """
    Response model for POST /training/build-dataset.

    Populated from the summary dict returned by ``build_dataset_from_db``.
    """

    # Name of the chart whose candles and annotations were used.
    chart_name: str
    # Number of candles exported to the raw CSV.
    n_candles: int
    # Number of rows in the labeled dataset.
    n_samples: int
    # Count of every labeled-CSV column except "label". This may include
    # time / programmatic-label columns that training later excludes.
    n_features: int
    # Location of the written labeled CSV, as a string path.
    labeled_path: str
|
||
|
||
|
||
@app.post("/training/build-dataset", response_model=BuildDatasetResponse, dependencies=[Depends(verify_api_key)])
async def training_build_dataset():
    """
    Build the labeled training dataset from database annotations.

    Exports candles, runs feature engineering, and ingests span annotations
    into a labeled CSV ready for training.

    Raises:
        HTTPException(400): ``build_dataset_from_db`` raised ValueError, which it
            uses for missing input data (no charts, no candles, or no labeled
            samples). The client gets a fixed, actionable hint; the exception
            text is logged server-side only.
        HTTPException(500): any other failure; details are logged only.
    """
    config = state.pipeline_config or get_default_config()

    try:
        result = build_dataset_from_db(config)
    except ValueError as exc:
        # A 400 is a client-actionable condition, so it must not say
        # "Internal server error". The raw message stays server-side because
        # ValueError can also originate in lower layers (pandas, feature code).
        logger.warning(f"Dataset build rejected: {exc}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Unable to build dataset: ensure candle data and span annotations exist",
        )
    except Exception as exc:
        logger.error(f"Failed to build dataset: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # Built outside the try: a pydantic ValidationError is a ValueError
    # subclass and must not be misreported as a client (400) error.
    return BuildDatasetResponse(**result)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Model Loading Endpoint
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ModelLoadRequest(BaseModel):
    """
    Request model for POST /model/load.

    The endpoint accepts only run IDs made of letters, digits, "_" and "-".
    """

    # Required; identifies a completed TrainingRun whose models/<run_id>.pkl
    # artifact should become the active model.
    run_id: str = Field(..., description="Training run ID to load model from")
|
||
|
||
|
||
class ModelLoadResponse(BaseModel):
    """Response model for POST /model/load."""

    # The run whose model is now active.
    run_id: str
    # model_type recorded on that TrainingRun row.
    model_type: str
    # "loaded" on success (the only value the endpoint returns).
    status: str
|
||
|
||
|
||
# Lock protecting model hot-swap: /model/load holds it while replacing
# state.model and state.model_info, so both are assigned together.
# NOTE(review): prediction handlers are not visible in this chunk. Confirm
# they also take this lock when reading both attributes; otherwise a reader
# can still observe the new model paired with the old model_info.
_model_swap_lock = threading.Lock()
|
||
|
||
|
||
@app.post("/model/load", response_model=ModelLoadResponse, dependencies=[Depends(verify_api_key)])
async def model_load(request: ModelLoadRequest):
    """
    Load a trained model from a completed training run.

    Looks up the run_id in the training_runs table, loads the model artifact,
    and replaces the active model in AppState with a brief lock to prevent
    conflicts with in-flight prediction requests.

    Raises:
        HTTPException(400): malformed run_id, or the run is not ``completed``.
        HTTPException(404): unknown run_id, or its artifact file is missing.
            The server filesystem path is logged, never returned.
        HTTPException(500): database lookup or model loading failed; details
            are logged server-side only.
    """
    from sqlalchemy import select

    run_id = request.run_id

    # 0. Validate run_id format to prevent path traversal. ``re.fullmatch``
    # rejects the trailing newline that ``re.match(r'^...$')`` lets through.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', run_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )

    # 1. Look up the training run. Copy the needed columns while the session
    # is still open so no ORM attribute is read from a detached instance.
    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
            row = db.execute(stmt).scalar_one_or_none()
            run_found = row is not None
            run_status = row.status if run_found else None
            run_model_type = row.model_type if run_found else None
    except Exception as exc:
        logger.error(f"DB lookup failed for run_id={run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    if not run_found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Training run not found: {run_id}",
        )

    if run_status != "completed":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Training run is not completed (status={run_status})",
        )

    # 2. Resolve model artifact path; it must be a direct child of models/.
    # Path comparison (not a "/"-suffixed string prefix) is separator-agnostic.
    models_base = Path("models").resolve()
    model_path = (Path("models") / f"{run_id}.pkl").resolve()
    if model_path.parent != models_base:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id format",
        )
    if not model_path.exists():
        # Keep the absolute server path out of the client-facing response.
        logger.warning(f"Model artifact for run_id={run_id} not found at {model_path}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model artifact not found",
        )

    # 3. Load model (outside lock – can be slow)
    try:
        new_model, new_model_info = load_model_from_local(str(model_path))
    except Exception as exc:
        logger.error(f"Failed to load model for run_id={run_id}: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

    # 4. Thread-safe model swap (3.2 – brief lock)
    with _model_swap_lock:
        state.model = new_model
        state.model_info = new_model_info

    logger.info(f"Model hot-swapped: run_id={run_id}, type={run_model_type}")

    return ModelLoadResponse(
        run_id=run_id,
        model_type=run_model_type,
        status="loaded",
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Direct-execution entry point: serve this FastAPI app with uvicorn,
    # listening on all interfaces, port 8001.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8001)
|