feat: add FastAPI model/load endpoint and all Next.js proxy routes (tasks 2-4)

This commit is contained in:
Marko Djordjevic 2026-02-17 18:47:04 +01:00
parent b8e649e333
commit 2a02669222
29 changed files with 1110 additions and 780 deletions

View file

@ -5,6 +5,8 @@ Provides REST API endpoints for model serving, health checks, and prediction.
"""
import logging
import threading
import uuid as uuid_lib
from pathlib import Path
from typing import Optional, Dict, Any, List
from datetime import datetime
@ -19,8 +21,10 @@ import joblib
import mlflow
import mlflow.sklearn
import mlflow.xgboost
from sqlalchemy import update as sa_update, desc
from app.config import load_config, PipelineConfig
from app.config import load_config, PipelineConfig, get_default_config
from app.db import get_db, TrainingRun
from app.preprocessing import preprocess_candles, extract_feature_columns
from app.patterns import (
TALIB_PATTERNS,
@ -60,7 +64,13 @@ class AppState:
pipeline_config: Optional[PipelineConfig] = None
feature_columns: Optional[List[str]] = None
label_encoder: Optional[Dict[int, str]] = None
# Training thread management
active_training_run_id: Optional[str] = None
training_lock: threading.Lock = None
def __init__(self) -> None:
    """Create the per-instance lock used by the training endpoints."""
    # Held by POST /training/start and by the background training thread while
    # they read or write ``active_training_run_id``.
    self.training_lock = threading.Lock()
# Module-level shared application state. The training endpoints coordinate through
# ``state.training_lock`` / ``state.active_training_run_id``; POST /model/load
# replaces ``state.model`` and ``state.model_info``.
state = AppState()
@ -826,6 +836,413 @@ async def patterns_detect(request: DetectPatternsRequest):
)
# ---------------------------------------------------------------------------
# Training Endpoints
# ---------------------------------------------------------------------------
# Model types accepted by POST /training/start; any other value is rejected
# with HTTP 400 by the handler.
SUPPORTED_MODEL_TYPES = [
    "random_forest",
    "xgboost",
]
class TrainingStartRequest(BaseModel):
    """Request body for POST /training/start.

    Attributes:
        model_type: Name of the model to train; defaults to ``"random_forest"``.
            The handler rejects values not listed in ``SUPPORTED_MODEL_TYPES``.
    """

    model_type: str = Field(
        default="random_forest",
        description="Model type: random_forest or xgboost",
    )
class TrainingStartResponse(BaseModel):
    """Response body for POST /training/start.

    Attributes:
        run_id: Identifier of the training run that was just created.
        status: Status of that run; the handler returns ``"running"``.
    """

    run_id: str
    status: str
class TrainingRunInfo(BaseModel):
    """One entry of the training-run history returned by GET /training/runs."""

    # Identity and state of the run.
    run_id: str
    model_type: str
    status: str
    experiment_name: Optional[str] = None

    # ISO-8601 strings produced with ``datetime.isoformat()``; None when unset.
    created_at: Optional[str] = None
    completed_at: Optional[str] = None

    # Free-form dict stored with the run (evaluation metrics, or an "error" entry
    # when the run failed).
    metrics_summary: Optional[Dict[str, Any]] = None
class TrainingRunsResponse(BaseModel):
    """Response body for GET /training/runs.

    Attributes:
        runs: Training runs in the order the handler fetched them
            (newest ``created_at`` first).
    """

    runs: List[TrainingRunInfo]
class DatasetInfoResponse(BaseModel):
    """Response body for GET /training/dataset-info."""

    # Configured location of the labeled dataset and whether it is present.
    path: str
    exists: bool

    # File details; all left as None when the file does not exist.
    size_bytes: Optional[int] = None
    last_modified: Optional[str] = None
    # None as well when the file exists but its rows could not be counted.
    row_count: Optional[int] = None
def _training_feature_columns(columns) -> List[str]:
    """Return the entries of ``columns`` that are used as model features.

    Drops ``label``, ``time`` and ``timestamp`` plus every column whose name
    starts with ``label_programmatic_``; the input order is preserved.
    """
    non_feature = ("label", "time", "timestamp")
    return [
        col
        for col in columns
        if col not in non_feature and not col.startswith("label_programmatic_")
    ]


def _save_training_artifact(
    run_id: str, model_type: str, model_instance: Any, feature_cols: List[str]
) -> Path:
    """Dump the fitted model plus metadata to ``models/<run_id>.pkl`` and return the path."""
    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)
    model_path = models_dir / f"{run_id}.pkl"
    model_data = {
        "model": model_instance,
        "metadata": {
            "model_type": model_type,
            "trained_at": datetime.utcnow().isoformat(),
            "run_id": run_id,
            "feature_columns": feature_cols,
            # Class labels are only recorded when the wrapper exposes a fitted
            # estimator as ``.model`` with a ``classes_`` attribute.
            "labels": (
                [str(c) for c in model_instance.model.classes_]
                if hasattr(model_instance, "model") and hasattr(model_instance.model, "classes_")
                else []
            ),
        },
    }
    joblib.dump(model_data, model_path)
    return model_path


def _set_training_run_outcome(
    run_id: str, outcome: str, metrics_summary: Dict[str, Any]
) -> None:
    """Store the final status, completion time and metrics summary of a run in the DB."""
    with get_db() as db:
        stmt = (
            sa_update(TrainingRun)
            .where(TrainingRun.run_id == run_id)
            .values(
                status=outcome,
                completed_at=datetime.utcnow(),
                metrics_summary=metrics_summary,
            )
        )
        db.execute(stmt)
        db.commit()


def _run_training_background(run_id: str, model_type: str, config: PipelineConfig) -> None:
    """
    Background thread target: train a model, update DB on completion or failure.

    Uses the pre-inserted TrainingRun record identified by ``run_id``.

    Steps: read the labeled CSV at ``config.data.labeled_path``, split it with
    ``temporal_split``, fit the model built by ``create_model``, compute
    validation/test accuracy and macro F1, save the artifact to
    ``models/<run_id>.pkl`` and mark the run ``completed``. Any exception marks
    the run ``failed`` with the error text as its metrics summary. In every case
    the active-run slot in ``state`` is released before the thread exits.

    Args:
        run_id: Identifier of the already-inserted TrainingRun row.
        model_type: Model name forwarded to ``create_model``.
        config: Pipeline configuration supplying the dataset path and the
            training settings (splits, hyperparameters, class weights).

    Returns:
        None. Nothing is raised to the caller; failures are logged and recorded
        in the database.
    """
    logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}")
    try:
        # Import training utilities here to avoid circular import issues
        from training.train import create_model, temporal_split
        from sklearn.metrics import accuracy_score, f1_score

        labeled_path = Path(config.data.labeled_path)
        if not labeled_path.exists():
            raise FileNotFoundError(f"Labeled dataset not found: {labeled_path}")

        # Load dataset
        df = pd.read_csv(labeled_path)
        if "label" not in df.columns:
            raise ValueError("Labeled dataset must have 'label' column")

        feature_cols = _training_feature_columns(df.columns)
        X = df[feature_cols].values
        y = df["label"].values
        logger.info(f"Loaded {len(X)} samples, {len(feature_cols)} features")

        # Split data
        training_cfg = config.stages.training
        X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
            X, y, training_cfg.test_split, training_cfg.validation_split
        )

        # Train model
        model_instance = create_model(
            model_type, training_cfg.hyperparameters, training_cfg.class_weights
        )
        model_instance.fit(X_train, y_train)
        logger.info("Model training complete")

        # Evaluate
        y_val_pred = model_instance.predict(X_val)
        y_test_pred = model_instance.predict(X_test)
        metrics = {
            "val_accuracy": float(accuracy_score(y_val, y_val_pred)),
            "val_f1_macro": float(
                f1_score(y_val, y_val_pred, average="macro", zero_division=0)
            ),
            "test_accuracy": float(accuracy_score(y_test, y_test_pred)),
            "test_f1_macro": float(
                f1_score(y_test, y_test_pred, average="macro", zero_division=0)
            ),
            "n_samples": int(len(X)),
            "n_features": int(X.shape[1]),
        }

        # Save model locally
        model_path = _save_training_artifact(run_id, model_type, model_instance, feature_cols)
        logger.info(f"Model saved to {model_path}")

        # Update DB: completed
        _set_training_run_outcome(run_id, "completed", metrics)
        logger.info(f"Training completed successfully: run_id={run_id}")
    except Exception as exc:
        logger.error(f"Training failed for run_id={run_id}: {exc}", exc_info=True)
        try:
            _set_training_run_outcome(run_id, "failed", {"error": str(exc)})
        except Exception as db_exc:
            # Best effort only: the thread must still reach ``finally`` below, so
            # the secondary failure is logged (with traceback) and not re-raised.
            logger.error(
                f"Failed to update DB for failed run {run_id}: {db_exc}",
                exc_info=True,
            )
    finally:
        # Free the slot only if it still belongs to this run.
        with state.training_lock:
            if state.active_training_run_id == run_id:
                state.active_training_run_id = None
        logger.info(f"Training thread exiting: run_id={run_id}")
@app.post("/training/start", response_model=TrainingStartResponse)
async def training_start(request: TrainingStartRequest):
    """
    Start a training run in a background thread.

    Returns immediately with run_id and status "running".
    Rejects concurrent runs with HTTP 409.

    Args:
        request: Body carrying the ``model_type`` to train.

    Returns:
        TrainingStartResponse with the new ``run_id`` and status ``"running"``.

    Raises:
        HTTPException: 400 for an unsupported model type, 409 while another run
            holds the training slot, 500 when the run record cannot be inserted.
    """
    # Validate model type
    if request.model_type not in SUPPORTED_MODEL_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
        )

    # Reject concurrent runs (atomic check-and-set)
    with state.training_lock:
        if state.active_training_run_id is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail={
                    "error": "Training already in progress",
                    "run_id": state.active_training_run_id,
                },
            )
        run_id = str(uuid_lib.uuid4())
        state.active_training_run_id = run_id

    # From here on the slot is reserved. Until the worker thread is running (and
    # takes over responsibility for releasing it), any failure must release the
    # slot, otherwise every later request would be answered with 409.
    thread_started = False
    try:
        config = state.pipeline_config or get_default_config()

        # Compute config hash (best-effort)
        config_hash = "unknown"
        try:
            from training.train import compute_config_hash
            config_hash = compute_config_hash(config)
        except Exception as hash_exc:
            logger.warning(
                f"Could not compute pipeline config hash, storing 'unknown': {hash_exc}"
            )

        # Pre-insert the run record so callers can track it immediately
        try:
            with get_db() as db:
                training_run = TrainingRun(
                    run_id=run_id,
                    model_type=request.model_type,
                    experiment_name=config.stages.training.mlflow.experiment_name,
                    pipeline_config_hash=config_hash,
                    status="running",
                    created_at=datetime.utcnow(),
                    metrics_summary={},
                )
                db.add(training_run)
                db.commit()
        except Exception as exc:
            logger.error(f"Failed to insert training run record: {exc}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to create training run record: {exc}",
            ) from exc

        # Launch background thread (daemon so it doesn't block process exit)
        thread = threading.Thread(
            target=_run_training_background,
            args=(run_id, request.model_type, config),
            daemon=True,
            name=f"training-{run_id[:8]}",
        )
        # NOTE(review): if start() itself raises, the slot is released below but
        # the pre-inserted row keeps status "running".
        thread.start()
        thread_started = True
    finally:
        if not thread_started:
            with state.training_lock:
                if state.active_training_run_id == run_id:
                    state.active_training_run_id = None

    logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}")
    return TrainingStartResponse(run_id=run_id, status="running")
@app.get("/training/runs", response_model=TrainingRunsResponse)
async def training_runs():
    """
    List every recorded training run, newest ``created_at`` first.

    Returns:
        TrainingRunsResponse wrapping one TrainingRunInfo per database row.

    Raises:
        HTTPException: 500 when the runs cannot be read or converted.
    """

    def _iso(moment):
        # Timestamps are exposed as ISO-8601 strings; unset ones stay None.
        return moment.isoformat() if moment else None

    try:
        from sqlalchemy import select

        newest_first = select(TrainingRun).order_by(desc(TrainingRun.created_at))
        runs = []
        with get_db() as db:
            # Rows are converted while the session is still open.
            for record in db.execute(newest_first).scalars().all():
                runs.append(
                    TrainingRunInfo(
                        run_id=record.run_id,
                        model_type=record.model_type,
                        status=record.status,
                        experiment_name=record.experiment_name,
                        created_at=_iso(record.created_at),
                        completed_at=_iso(record.completed_at),
                        metrics_summary=record.metrics_summary,
                    )
                )
    except Exception as exc:
        logger.error(f"Failed to fetch training runs: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to fetch training runs: {exc}",
        )

    return TrainingRunsResponse(runs=runs)
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
async def training_dataset_info():
    """
    Return information about the labeled training dataset.

    Includes file path, existence, size, last modified date, and row count.

    Returns:
        DatasetInfoResponse. When the file is missing only ``path`` and
        ``exists=False`` are filled in. ``row_count`` is None when the file
        exists but cannot be parsed as CSV.

    Raises:
        HTTPException: 500 when the file metadata cannot be read.
    """
    config = state.pipeline_config or get_default_config()
    labeled_path = Path(config.data.labeled_path)

    if not labeled_path.exists():
        return DatasetInfoResponse(path=str(labeled_path), exists=False)

    try:
        stat = labeled_path.stat()
        size_bytes = stat.st_size
        # NOTE(review): fromtimestamp() without a tz yields naive local time,
        # while the training-run timestamps in this file use utcnow() — confirm
        # which one API consumers expect.
        last_modified = datetime.fromtimestamp(stat.st_mtime).isoformat()

        row_count = None
        try:
            # Parse only the first column to keep memory low; the whole file is
            # still scanned to count its rows.
            first_column = pd.read_csv(labeled_path, usecols=[0])
            row_count = len(first_column)
        except Exception as count_exc:
            # Best effort: an unreadable CSV still yields size/mtime, but the
            # reason is logged instead of being dropped silently.
            logger.warning(f"Could not count rows in {labeled_path}: {count_exc}")

        return DatasetInfoResponse(
            path=str(labeled_path),
            exists=True,
            size_bytes=size_bytes,
            last_modified=last_modified,
            row_count=row_count,
        )
    except Exception as exc:
        logger.error(f"Failed to get dataset info: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get dataset info: {exc}",
        ) from exc
# ---------------------------------------------------------------------------
# Model Loading Endpoint
# ---------------------------------------------------------------------------
class ModelLoadRequest(BaseModel):
    """Request body for POST /model/load.

    Attributes:
        run_id: Required. Identifier of the training run whose saved model
            should become the active one.
    """

    run_id: str = Field(
        ...,
        description="Training run ID to load model from",
    )
class ModelLoadResponse(BaseModel):
    """Response body for POST /model/load.

    Attributes:
        run_id: Training run the model was loaded from.
        model_type: Model type stored on that training run.
        status: Outcome of the request; the handler returns ``"loaded"``.
    """

    run_id: str
    model_type: str
    status: str
# Lock protecting model hot-swap: POST /model/load holds it while it replaces
# ``state.model`` and ``state.model_info`` together.
# NOTE(review): this only excludes readers that acquire the same lock; whether
# the prediction handlers do so is not visible in this section — confirm.
_model_swap_lock = threading.Lock()
@app.post("/model/load", response_model=ModelLoadResponse)
async def model_load(request: ModelLoadRequest):
    """
    Load a trained model from a completed training run.

    Looks up the run_id in the training_runs table, loads the model artifact
    from ``models/<run_id>.pkl``, and replaces the active model in AppState
    while holding ``_model_swap_lock``.

    Args:
        request: Body carrying the ``run_id`` to load.

    Returns:
        ModelLoadResponse with the run id, its model type and status ``"loaded"``.

    Raises:
        HTTPException: 404 when the run or its artifact file does not exist,
            400 when the run is not ``completed``, 500 on database or
            model-loading errors.
    """
    from sqlalchemy import select

    # 1. Look up the training run. The needed column values are copied while the
    #    session is open so nothing touches the ORM instance after the session
    #    has exited (same pattern as GET /training/runs).
    run_found = False
    run_status = None
    run_model_type = None
    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == request.run_id)
            row = db.execute(stmt).scalar_one_or_none()
            if row is not None:
                run_found = True
                run_status = row.status
                run_model_type = row.model_type
    except Exception as exc:
        logger.error(f"DB lookup failed for run_id={request.run_id}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Database error: {exc}",
        ) from exc

    if not run_found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Training run not found: {request.run_id}",
        )

    if run_status != "completed":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Training run is not completed (status={run_status})",
        )

    # 2. Resolve model artifact path (relative to the process working directory,
    #    matching where the training thread writes it).
    model_path = Path("models") / f"{request.run_id}.pkl"
    if not model_path.exists():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model artifact not found at {model_path}",
        )

    # 3. Load the model outside the lock, because loading can be slow.
    try:
        new_model, new_model_info = load_model_from_local(str(model_path))
    except Exception as exc:
        logger.error(f"Failed to load model for run_id={request.run_id}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to load model: {exc}",
        ) from exc

    # 4. Thread-safe model swap: the lock is held only for the two assignments.
    # NOTE(review): ``state.feature_columns`` / ``state.label_encoder`` are not
    # refreshed here — confirm whether predictions rely on them after a swap.
    with _model_swap_lock:
        state.model = new_model
        state.model_info = new_model_info

    logger.info(f"Model hot-swapped: run_id={request.run_id}, type={run_model_type}")

    return ModelLoadResponse(
        run_id=request.run_id,
        model_type=run_model_type,
        status="loaded",
    )
if __name__ == "__main__":
    # Direct-execution entry point: serve ``app`` with uvicorn, listening on
    # every interface at port 8001.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
    )