feat: add FastAPI model/load endpoint and all Next.js proxy routes (tasks 2-4)
This commit is contained in:
parent
b8e649e333
commit
2a02669222
29 changed files with 1110 additions and 780 deletions
|
|
@ -5,6 +5,8 @@ Provides REST API endpoints for model serving, health checks, and prediction.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import uuid as uuid_lib
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
|
@ -19,8 +21,10 @@ import joblib
|
|||
import mlflow
|
||||
import mlflow.sklearn
|
||||
import mlflow.xgboost
|
||||
from sqlalchemy import update as sa_update, desc
|
||||
|
||||
from app.config import load_config, PipelineConfig
|
||||
from app.config import load_config, PipelineConfig, get_default_config
|
||||
from app.db import get_db, TrainingRun
|
||||
from app.preprocessing import preprocess_candles, extract_feature_columns
|
||||
from app.patterns import (
|
||||
TALIB_PATTERNS,
|
||||
|
|
@ -60,7 +64,13 @@ class AppState:
|
|||
pipeline_config: Optional[PipelineConfig] = None
|
||||
feature_columns: Optional[List[str]] = None
|
||||
label_encoder: Optional[Dict[int, str]] = None
|
||||
|
||||
# Training thread management
|
||||
active_training_run_id: Optional[str] = None
|
||||
training_lock: threading.Lock = None
|
||||
|
||||
def __init__(self):
|
||||
self.training_lock = threading.Lock()
|
||||
|
||||
state = AppState()
|
||||
|
||||
|
||||
|
|
@ -826,6 +836,413 @@ async def patterns_detect(request: DetectPatternsRequest):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SUPPORTED_MODEL_TYPES = ["random_forest", "xgboost"]
|
||||
|
||||
|
||||
class TrainingStartRequest(BaseModel):
|
||||
"""Request model for POST /training/start."""
|
||||
model_type: str = Field(
|
||||
"random_forest",
|
||||
description="Model type: random_forest or xgboost",
|
||||
)
|
||||
|
||||
|
||||
class TrainingStartResponse(BaseModel):
|
||||
"""Response model for POST /training/start."""
|
||||
run_id: str
|
||||
status: str
|
||||
|
||||
|
||||
class TrainingRunInfo(BaseModel):
|
||||
"""Summary of a single training run."""
|
||||
run_id: str
|
||||
model_type: str
|
||||
status: str
|
||||
experiment_name: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
metrics_summary: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TrainingRunsResponse(BaseModel):
|
||||
"""Response model for GET /training/runs."""
|
||||
runs: List[TrainingRunInfo]
|
||||
|
||||
|
||||
class DatasetInfoResponse(BaseModel):
|
||||
"""Response model for GET /training/dataset-info."""
|
||||
path: str
|
||||
exists: bool
|
||||
size_bytes: Optional[int] = None
|
||||
last_modified: Optional[str] = None
|
||||
row_count: Optional[int] = None
|
||||
|
||||
|
||||
def _run_training_background(run_id: str, model_type: str, config: PipelineConfig) -> None:
|
||||
"""
|
||||
Background thread target: train a model, update DB on completion or failure.
|
||||
|
||||
Uses the pre-inserted TrainingRun record identified by ``run_id``.
|
||||
"""
|
||||
logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}")
|
||||
|
||||
try:
|
||||
# Import training utilities here to avoid circular import issues
|
||||
from training.train import create_model, temporal_split
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
|
||||
labeled_path = Path(config.data.labeled_path)
|
||||
if not labeled_path.exists():
|
||||
raise FileNotFoundError(f"Labeled dataset not found: {labeled_path}")
|
||||
|
||||
# Load dataset
|
||||
df = pd.read_csv(labeled_path)
|
||||
|
||||
if "label" not in df.columns:
|
||||
raise ValueError("Labeled dataset must have 'label' column")
|
||||
|
||||
feature_cols = [
|
||||
col for col in df.columns
|
||||
if col not in ("label", "time", "timestamp")
|
||||
and not col.startswith("label_programmatic_")
|
||||
]
|
||||
|
||||
X = df[feature_cols].values
|
||||
y = df["label"].values
|
||||
|
||||
logger.info(f"Loaded {len(X)} samples, {len(feature_cols)} features")
|
||||
|
||||
# Split data
|
||||
training_cfg = config.stages.training
|
||||
X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
|
||||
X, y, training_cfg.test_split, training_cfg.validation_split
|
||||
)
|
||||
|
||||
# Train model
|
||||
model_instance = create_model(
|
||||
model_type, training_cfg.hyperparameters, training_cfg.class_weights
|
||||
)
|
||||
model_instance.fit(X_train, y_train)
|
||||
logger.info("Model training complete")
|
||||
|
||||
# Evaluate
|
||||
y_val_pred = model_instance.predict(X_val)
|
||||
y_test_pred = model_instance.predict(X_test)
|
||||
|
||||
metrics = {
|
||||
"val_accuracy": float(accuracy_score(y_val, y_val_pred)),
|
||||
"val_f1_macro": float(
|
||||
f1_score(y_val, y_val_pred, average="macro", zero_division=0)
|
||||
),
|
||||
"test_accuracy": float(accuracy_score(y_test, y_test_pred)),
|
||||
"test_f1_macro": float(
|
||||
f1_score(y_test, y_test_pred, average="macro", zero_division=0)
|
||||
),
|
||||
"n_samples": int(len(X)),
|
||||
"n_features": int(X.shape[1]),
|
||||
}
|
||||
|
||||
# Save model locally
|
||||
models_dir = Path("models")
|
||||
models_dir.mkdir(exist_ok=True)
|
||||
model_path = models_dir / f"{run_id}.pkl"
|
||||
|
||||
model_data = {
|
||||
"model": model_instance,
|
||||
"metadata": {
|
||||
"model_type": model_type,
|
||||
"trained_at": datetime.utcnow().isoformat(),
|
||||
"run_id": run_id,
|
||||
"feature_columns": feature_cols,
|
||||
"labels": (
|
||||
[str(c) for c in model_instance.model.classes_]
|
||||
if hasattr(model_instance, "model") and hasattr(model_instance.model, "classes_")
|
||||
else []
|
||||
),
|
||||
},
|
||||
}
|
||||
joblib.dump(model_data, model_path)
|
||||
logger.info(f"Model saved to {model_path}")
|
||||
|
||||
# Update DB: completed
|
||||
with get_db() as db:
|
||||
stmt = (
|
||||
sa_update(TrainingRun)
|
||||
.where(TrainingRun.run_id == run_id)
|
||||
.values(
|
||||
status="completed",
|
||||
completed_at=datetime.utcnow(),
|
||||
metrics_summary=metrics,
|
||||
)
|
||||
)
|
||||
db.execute(stmt)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Training completed successfully: run_id={run_id}")
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Training failed for run_id={run_id}: {exc}", exc_info=True)
|
||||
try:
|
||||
with get_db() as db:
|
||||
stmt = (
|
||||
sa_update(TrainingRun)
|
||||
.where(TrainingRun.run_id == run_id)
|
||||
.values(
|
||||
status="failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
metrics_summary={"error": str(exc)},
|
||||
)
|
||||
)
|
||||
db.execute(stmt)
|
||||
db.commit()
|
||||
except Exception as db_exc:
|
||||
logger.error(f"Failed to update DB for failed run {run_id}: {db_exc}")
|
||||
finally:
|
||||
with state.training_lock:
|
||||
if state.active_training_run_id == run_id:
|
||||
state.active_training_run_id = None
|
||||
logger.info(f"Training thread exiting: run_id={run_id}")
|
||||
|
||||
|
||||
@app.post("/training/start", response_model=TrainingStartResponse)
|
||||
async def training_start(request: TrainingStartRequest):
|
||||
"""
|
||||
Start a training run in a background thread.
|
||||
|
||||
Returns immediately with run_id and status "running".
|
||||
Rejects concurrent runs with HTTP 409.
|
||||
"""
|
||||
# Validate model type
|
||||
if request.model_type not in SUPPORTED_MODEL_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
|
||||
)
|
||||
|
||||
# Reject concurrent runs (atomic check-and-set)
|
||||
with state.training_lock:
|
||||
if state.active_training_run_id is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"error": "Training already in progress",
|
||||
"run_id": state.active_training_run_id,
|
||||
},
|
||||
)
|
||||
run_id = str(uuid_lib.uuid4())
|
||||
state.active_training_run_id = run_id
|
||||
|
||||
config = state.pipeline_config or get_default_config()
|
||||
|
||||
# Compute config hash (best-effort)
|
||||
config_hash = "unknown"
|
||||
try:
|
||||
from training.train import compute_config_hash
|
||||
config_hash = compute_config_hash(config)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Pre-insert the run record so callers can track it immediately
|
||||
try:
|
||||
with get_db() as db:
|
||||
training_run = TrainingRun(
|
||||
run_id=run_id,
|
||||
model_type=request.model_type,
|
||||
experiment_name=config.stages.training.mlflow.experiment_name,
|
||||
pipeline_config_hash=config_hash,
|
||||
status="running",
|
||||
created_at=datetime.utcnow(),
|
||||
metrics_summary={},
|
||||
)
|
||||
db.add(training_run)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
with state.training_lock:
|
||||
state.active_training_run_id = None
|
||||
logger.error(f"Failed to insert training run record: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create training run record: {exc}",
|
||||
)
|
||||
|
||||
# Launch background thread (daemon so it doesn't block process exit)
|
||||
thread = threading.Thread(
|
||||
target=_run_training_background,
|
||||
args=(run_id, request.model_type, config),
|
||||
daemon=True,
|
||||
name=f"training-{run_id[:8]}",
|
||||
)
|
||||
thread.start()
|
||||
|
||||
logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}")
|
||||
return TrainingStartResponse(run_id=run_id, status="running")
|
||||
|
||||
|
||||
@app.get("/training/runs", response_model=TrainingRunsResponse)
|
||||
async def training_runs():
|
||||
"""
|
||||
Return training run history from the database, sorted by date descending.
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
with get_db() as db:
|
||||
stmt = select(TrainingRun).order_by(desc(TrainingRun.created_at))
|
||||
rows = db.execute(stmt).scalars().all()
|
||||
|
||||
runs = [
|
||||
TrainingRunInfo(
|
||||
run_id=row.run_id,
|
||||
model_type=row.model_type,
|
||||
status=row.status,
|
||||
experiment_name=row.experiment_name,
|
||||
created_at=row.created_at.isoformat() if row.created_at else None,
|
||||
completed_at=row.completed_at.isoformat() if row.completed_at else None,
|
||||
metrics_summary=row.metrics_summary,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch training runs: {exc}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to fetch training runs: {exc}",
|
||||
)
|
||||
|
||||
return TrainingRunsResponse(runs=runs)
|
||||
|
||||
|
||||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
|
||||
async def training_dataset_info():
|
||||
"""
|
||||
Return information about the labeled training dataset.
|
||||
|
||||
Includes file path, existence, size, last modified date, and row count.
|
||||
"""
|
||||
config = state.pipeline_config or get_default_config()
|
||||
labeled_path = Path(config.data.labeled_path)
|
||||
|
||||
if not labeled_path.exists():
|
||||
return DatasetInfoResponse(path=str(labeled_path), exists=False)
|
||||
|
||||
try:
|
||||
stat = labeled_path.stat()
|
||||
size_bytes = stat.st_size
|
||||
last_modified = datetime.fromtimestamp(stat.st_mtime).isoformat()
|
||||
|
||||
row_count = None
|
||||
try:
|
||||
# Read only one column for efficiency
|
||||
df_head = pd.read_csv(labeled_path, usecols=[0])
|
||||
row_count = len(df_head)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return DatasetInfoResponse(
|
||||
path=str(labeled_path),
|
||||
exists=True,
|
||||
size_bytes=size_bytes,
|
||||
last_modified=last_modified,
|
||||
row_count=row_count,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to get dataset info: {exc}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get dataset info: {exc}",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model Loading Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ModelLoadRequest(BaseModel):
|
||||
"""Request model for POST /model/load."""
|
||||
run_id: str = Field(..., description="Training run ID to load model from")
|
||||
|
||||
|
||||
class ModelLoadResponse(BaseModel):
|
||||
"""Response model for POST /model/load."""
|
||||
run_id: str
|
||||
model_type: str
|
||||
status: str
|
||||
|
||||
|
||||
# Lock protecting model hot-swap
|
||||
_model_swap_lock = threading.Lock()
|
||||
|
||||
|
||||
@app.post("/model/load", response_model=ModelLoadResponse)
|
||||
async def model_load(request: ModelLoadRequest):
|
||||
"""
|
||||
Load a trained model from a completed training run.
|
||||
|
||||
Looks up the run_id in the training_runs table, loads the model artifact,
|
||||
and replaces the active model in AppState with a brief lock to prevent
|
||||
conflicts with in-flight prediction requests.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# 1. Look up the training run
|
||||
try:
|
||||
with get_db() as db:
|
||||
stmt = select(TrainingRun).where(TrainingRun.run_id == request.run_id)
|
||||
row = db.execute(stmt).scalar_one_or_none()
|
||||
except Exception as exc:
|
||||
logger.error(f"DB lookup failed for run_id={request.run_id}: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {exc}",
|
||||
)
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Training run not found: {request.run_id}",
|
||||
)
|
||||
|
||||
if row.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Training run is not completed (status={row.status})",
|
||||
)
|
||||
|
||||
# 2. Resolve model artifact path
|
||||
model_path = Path("models") / f"{request.run_id}.pkl"
|
||||
if not model_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model artifact not found at {model_path}",
|
||||
)
|
||||
|
||||
# 3. Load model (outside lock – can be slow)
|
||||
try:
|
||||
new_model, new_model_info = load_model_from_local(str(model_path))
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to load model for run_id={request.run_id}: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to load model: {exc}",
|
||||
)
|
||||
|
||||
# 4. Thread-safe model swap (3.2 – brief lock)
|
||||
with _model_swap_lock:
|
||||
state.model = new_model
|
||||
state.model_info = new_model_info
|
||||
|
||||
logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")
|
||||
|
||||
return ModelLoadResponse(
|
||||
run_id=request.run_id,
|
||||
model_type=row.model_type,
|
||||
status="loaded",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue