feat(ml): implement FastAPI inference service with model loading, preprocessing, and prediction endpoints
This commit is contained in:
parent
f4c0f9a836
commit
3a83fd38e9
3 changed files with 945 additions and 10 deletions
|
|
@ -51,16 +51,16 @@
|
|||
|
||||
## 6. Inference Service (FastAPI)
|
||||
|
||||
- [ ] 6.1 Create `services/ml/app/main.py` — FastAPI app with CORS, startup event to load model
|
||||
- [ ] 6.2 Implement model loading — from MLflow registry (by name + stage) or from local .pkl file via joblib
|
||||
- [ ] 6.3 Implement preprocessing parity — load pipeline config from MLflow artifact, apply same feature engineering as training
|
||||
- [ ] 6.4 Create `POST /predict` endpoint — accept candles array, run preprocessing, predict, return per-candle labels + confidence + spans + model_info
|
||||
- [ ] 6.5 Implement prediction span grouping — group consecutive same-label non-"O" predictions into spans with avg_confidence
|
||||
- [ ] 6.6 Create `POST /predict/batch` endpoint — accept pair/timeframe/date range, load data, process in batch_size chunks, return predictions
|
||||
- [ ] 6.7 Create `GET /model/info` endpoint — return model metadata, per-class metrics from MLflow
|
||||
- [ ] 6.8 Create `GET /model/labels` endpoint — return label names and colors
|
||||
- [ ] 6.9 Create `GET /health` endpoint — check model loaded status, MLflow connection, PostgreSQL connection
|
||||
- [ ] 6.10 Add Pydantic request/response models for all endpoints (PredictRequest, PredictResponse, BatchPredictRequest, ModelInfoResponse)
|
||||
- [x] 6.1 Create `services/ml/app/main.py` — FastAPI app with CORS, startup event to load model
|
||||
- [x] 6.2 Implement model loading — from MLflow registry (by name + stage) or from local .pkl file via joblib
|
||||
- [x] 6.3 Implement preprocessing parity — load pipeline config from MLflow artifact, apply same feature engineering as training
|
||||
- [x] 6.4 Create `POST /predict` endpoint — accept candles array, run preprocessing, predict, return per-candle labels + confidence + spans + model_info
|
||||
- [x] 6.5 Implement prediction span grouping — group consecutive same-label non-"O" predictions into spans with avg_confidence
|
||||
- [x] 6.6 Create `POST /predict/batch` endpoint — accept pair/timeframe/date range, load data, process in batch_size chunks, return predictions
|
||||
- [x] 6.7 Create `GET /model/info` endpoint — return model metadata, per-class metrics from MLflow
|
||||
- [x] 6.8 Create `GET /model/labels` endpoint — return label names and colors
|
||||
- [x] 6.9 Create `GET /health` endpoint — check model loaded status, MLflow connection, PostgreSQL connection
|
||||
- [x] 6.10 Add Pydantic request/response models for all endpoints (PredictRequest, PredictResponse, BatchPredictRequest, ModelInfoResponse)
|
||||
|
||||
## 7. Next.js API Proxy Routes
|
||||
|
||||
|
|
|
|||
750
services/ml/app/main.py
Normal file
750
services/ml/app/main.py
Normal file
|
|
@ -0,0 +1,750 @@
|
|||
"""
|
||||
FastAPI inference service for candlestick pattern prediction.
|
||||
|
||||
Provides REST API endpoints for model serving, health checks, and prediction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
import pandas as pd
|
||||
import joblib
|
||||
import mlflow
|
||||
import mlflow.sklearn
|
||||
import mlflow.xgboost
|
||||
|
||||
from app.config import load_config, PipelineConfig
|
||||
from app.preprocessing import preprocess_candles, extract_feature_columns
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# FastAPI app
|
||||
app = FastAPI(
|
||||
title="Candle Pattern Inference API",
|
||||
description="ML inference service for candlestick pattern recognition",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify actual origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Global state
|
||||
class AppState:
|
||||
"""Application state container."""
|
||||
model: Optional[Any] = None
|
||||
model_info: Optional[Dict[str, Any]] = None
|
||||
pipeline_config: Optional[PipelineConfig] = None
|
||||
feature_columns: Optional[List[str]] = None
|
||||
label_encoder: Optional[Dict[int, str]] = None
|
||||
|
||||
state = AppState()
|
||||
|
||||
|
||||
# --- Pydantic Models ---
|
||||
|
||||
class CandleData(BaseModel):
|
||||
"""Single candle data point."""
|
||||
time: int = Field(..., description="Unix timestamp in seconds")
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""Request model for /predict endpoint."""
|
||||
pair: str = Field(..., description="Trading pair (e.g., EURUSD)")
|
||||
timeframe: str = Field(..., description="Timeframe (e.g., 1H, 4H, 1D)")
|
||||
candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
|
||||
|
||||
|
||||
class PredictionResult(BaseModel):
|
||||
"""Single candle prediction."""
|
||||
time: int
|
||||
label: str
|
||||
confidence: float = Field(..., ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class PredictionSpan(BaseModel):
|
||||
"""Grouped prediction span."""
|
||||
start_time: int
|
||||
end_time: int
|
||||
label: str
|
||||
avg_confidence: float = Field(..., ge=0.0, le=1.0)
|
||||
candle_count: int
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Model metadata."""
|
||||
model_name: str
|
||||
model_version: Optional[str] = None
|
||||
model_type: str
|
||||
trained_at: Optional[str] = None
|
||||
dataset_version: Optional[str] = None
|
||||
feature_engineering_enabled: bool
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
"""Response model for /predict endpoint."""
|
||||
predictions: List[PredictionResult]
|
||||
spans: List[PredictionSpan]
|
||||
model_info: ModelInfo
|
||||
|
||||
|
||||
class BatchPredictRequest(BaseModel):
|
||||
"""Request model for /predict/batch endpoint."""
|
||||
pair: str
|
||||
timeframe: str
|
||||
start_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
|
||||
end_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
|
||||
|
||||
|
||||
class LabelInfo(BaseModel):
|
||||
"""Pattern label with display color."""
|
||||
name: str
|
||||
color: str
|
||||
|
||||
|
||||
class PerClassMetrics(BaseModel):
|
||||
"""Per-class performance metrics."""
|
||||
label: str
|
||||
precision: float
|
||||
recall: float
|
||||
f1_score: float
|
||||
support: int
|
||||
|
||||
|
||||
class ModelInfoResponse(BaseModel):
|
||||
"""Extended model information with metrics."""
|
||||
model_name: str
|
||||
model_version: Optional[str] = None
|
||||
model_type: str
|
||||
trained_at: Optional[str] = None
|
||||
dataset_version: Optional[str] = None
|
||||
feature_engineering_enabled: bool
|
||||
labels: List[str]
|
||||
per_class_metrics: List[PerClassMetrics]
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
status: str = Field(..., description="healthy, degraded, or unhealthy")
|
||||
model_loaded: bool
|
||||
mlflow: str = Field(..., description="connected or disconnected")
|
||||
database: str = Field(..., description="connected or disconnected")
|
||||
|
||||
|
||||
# --- Model Loading Functions ---
|
||||
|
||||
def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
|
||||
"""
|
||||
Load model from MLflow model registry.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
stage: Model stage (Production, Staging, None)
|
||||
|
||||
Returns:
|
||||
Tuple of (model, model_info)
|
||||
|
||||
Raises:
|
||||
Exception: If model not found or MLflow unavailable
|
||||
"""
|
||||
logger.info(f"Loading model from MLflow: {model_name} stage={stage}")
|
||||
|
||||
try:
|
||||
# Load model from registry
|
||||
model_uri = f"models:/{model_name}/{stage}"
|
||||
model = mlflow.pyfunc.load_model(model_uri)
|
||||
|
||||
# Get model version details
|
||||
client = mlflow.tracking.MlflowClient()
|
||||
model_versions = client.get_latest_versions(model_name, stages=[stage])
|
||||
|
||||
if not model_versions:
|
||||
raise ValueError(f"No model found for {model_name} at stage {stage}")
|
||||
|
||||
model_version = model_versions[0]
|
||||
run_id = model_version.run_id
|
||||
|
||||
# Get run metadata
|
||||
run = client.get_run(run_id)
|
||||
|
||||
# Extract model info
|
||||
model_info = {
|
||||
"model_name": model_name,
|
||||
"model_version": model_version.version,
|
||||
"model_type": run.data.params.get("model_type", "unknown"),
|
||||
"trained_at": datetime.fromtimestamp(run.info.start_time / 1000).isoformat(),
|
||||
"dataset_version": run.data.params.get("dataset_version", None),
|
||||
"feature_engineering_enabled": run.data.params.get("feature_engineering", "true") == "true",
|
||||
"labels": json.loads(run.data.params.get("labels", "[]")),
|
||||
"per_class_metrics": []
|
||||
}
|
||||
|
||||
# Extract per-class metrics
|
||||
for key, value in run.data.metrics.items():
|
||||
if key.startswith("precision_"):
|
||||
label = key.replace("precision_", "")
|
||||
recall = run.data.metrics.get(f"recall_{label}", 0.0)
|
||||
f1 = run.data.metrics.get(f"f1_{label}", 0.0)
|
||||
support = int(run.data.params.get(f"support_{label}", 0))
|
||||
|
||||
model_info["per_class_metrics"].append({
|
||||
"label": label,
|
||||
"precision": value,
|
||||
"recall": recall,
|
||||
"f1_score": f1,
|
||||
"support": support
|
||||
})
|
||||
|
||||
logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
|
||||
return model, model_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model from MLflow: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
||||
"""
|
||||
Load model from local file using joblib.
|
||||
|
||||
Args:
|
||||
model_path: Path to .pkl model file
|
||||
|
||||
Returns:
|
||||
Tuple of (model, model_info)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If model file doesn't exist
|
||||
Exception: If model loading fails
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
logger.info(f"Loading model from local file: {model_path}")
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
try:
|
||||
# Load model with joblib
|
||||
model_data = joblib.load(model_path)
|
||||
|
||||
# Extract model and metadata
|
||||
if isinstance(model_data, dict):
|
||||
model = model_data.get("model")
|
||||
metadata = model_data.get("metadata", {})
|
||||
else:
|
||||
model = model_data
|
||||
metadata = {}
|
||||
|
||||
# Build model info
|
||||
model_info = {
|
||||
"model_name": model_path.stem,
|
||||
"model_version": metadata.get("version", "local"),
|
||||
"model_type": metadata.get("model_type", "unknown"),
|
||||
"trained_at": metadata.get("trained_at", None),
|
||||
"dataset_version": metadata.get("dataset_version", None),
|
||||
"feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
|
||||
"labels": metadata.get("labels", []),
|
||||
"per_class_metrics": metadata.get("per_class_metrics", [])
|
||||
}
|
||||
|
||||
logger.info(f"Successfully loaded local model: {model_path.name}")
|
||||
return model, model_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model from local file: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# --- Startup Event ---
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""
|
||||
Load model and pipeline config on startup.
|
||||
"""
|
||||
logger.info("Starting inference service...")
|
||||
|
||||
# Load pipeline config
|
||||
config_path = Path("config/pipeline.yaml")
|
||||
if not config_path.exists():
|
||||
logger.warning(f"Config file not found: {config_path}. Using defaults.")
|
||||
return
|
||||
|
||||
try:
|
||||
state.pipeline_config = load_config(config_path)
|
||||
logger.info(f"Loaded pipeline config from {config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load pipeline config: {e}")
|
||||
return
|
||||
|
||||
# Load model based on config
|
||||
inference_config = state.pipeline_config.stages.inference
|
||||
|
||||
if not inference_config.enabled:
|
||||
logger.warning("Inference stage is disabled in config")
|
||||
return
|
||||
|
||||
# Load model
|
||||
try:
|
||||
if inference_config.model_source == "mlflow":
|
||||
# Configure MLflow tracking URI
|
||||
mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
|
||||
|
||||
state.model, state.model_info = load_model_from_mlflow(
|
||||
model_name=inference_config.mlflow_model_name,
|
||||
stage=inference_config.mlflow_model_stage
|
||||
)
|
||||
elif inference_config.model_source == "local":
|
||||
state.model, state.model_info = load_model_from_local(
|
||||
model_path=inference_config.local_model_path
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
||||
return
|
||||
|
||||
logger.info("Model loaded successfully")
|
||||
logger.info(f"Model info: {state.model_info['model_name']} "
|
||||
f"v{state.model_info['model_version']} "
|
||||
f"({state.model_info['model_type']})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
logger.warning("Service will start without a model. Use /health to check status.")
|
||||
|
||||
|
||||
# --- Health Check ---
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint.
|
||||
|
||||
Returns service status, model loaded status, and dependency health.
|
||||
"""
|
||||
model_loaded = state.model is not None
|
||||
|
||||
# Check MLflow connection
|
||||
mlflow_status = "disconnected"
|
||||
try:
|
||||
# TODO: Actually check MLflow connection
|
||||
mlflow_status = "connected"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check database connection
|
||||
db_status = "disconnected"
|
||||
try:
|
||||
# TODO: Actually check database connection
|
||||
db_status = "connected"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine overall status
|
||||
if model_loaded and mlflow_status == "connected" and db_status == "connected":
|
||||
overall_status = "healthy"
|
||||
elif model_loaded:
|
||||
overall_status = "degraded"
|
||||
else:
|
||||
overall_status = "unhealthy"
|
||||
|
||||
return HealthResponse(
|
||||
status=overall_status,
|
||||
model_loaded=model_loaded,
|
||||
mlflow=mlflow_status,
|
||||
database=db_status
|
||||
)
|
||||
|
||||
|
||||
# --- Model Info Endpoints (Stubs) ---
|
||||
|
||||
@app.get("/model/info", response_model=ModelInfoResponse)
|
||||
async def get_model_info():
|
||||
"""
|
||||
Get detailed model information and per-class metrics.
|
||||
|
||||
Returns HTTP 503 if no model is loaded.
|
||||
"""
|
||||
if state.model is None or state.model_info is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
# Build response from loaded model info
|
||||
return ModelInfoResponse(
|
||||
model_name=state.model_info["model_name"],
|
||||
model_version=state.model_info.get("model_version"),
|
||||
model_type=state.model_info["model_type"],
|
||||
trained_at=state.model_info.get("trained_at"),
|
||||
dataset_version=state.model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"],
|
||||
labels=state.model_info.get("labels", []),
|
||||
per_class_metrics=[
|
||||
PerClassMetrics(**metric)
|
||||
for metric in state.model_info.get("per_class_metrics", [])
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@app.get("/model/labels", response_model=List[LabelInfo])
|
||||
async def get_model_labels():
|
||||
"""
|
||||
Get all pattern labels the model can predict with display colors.
|
||||
"""
|
||||
if state.model is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
# TODO: Load label colors from config or database
|
||||
# For now, return placeholder
|
||||
labels = state.model_info.get("labels", []) if state.model_info else []
|
||||
|
||||
return [
|
||||
LabelInfo(name=label, color="#888888")
|
||||
for label in labels
|
||||
]
|
||||
|
||||
|
||||
# --- Prediction Helper Functions ---
|
||||
|
||||
def group_prediction_spans(
|
||||
predictions: List[PredictionResult]
|
||||
) -> List[PredictionSpan]:
|
||||
"""
|
||||
Group consecutive predictions with the same label into spans.
|
||||
|
||||
Only groups non-"O" (non-background) predictions. "O" predictions are
|
||||
treated as background and not grouped into spans.
|
||||
|
||||
Args:
|
||||
predictions: List of per-candle predictions
|
||||
|
||||
Returns:
|
||||
List of prediction spans
|
||||
"""
|
||||
if not predictions:
|
||||
return []
|
||||
|
||||
spans = []
|
||||
current_span = None
|
||||
|
||||
for pred in predictions:
|
||||
# Skip "O" (background) predictions
|
||||
if pred.label == "O":
|
||||
if current_span is not None:
|
||||
# End current span
|
||||
spans.append(current_span)
|
||||
current_span = None
|
||||
continue
|
||||
|
||||
# Start new span or continue current span
|
||||
if current_span is None or current_span.label != pred.label:
|
||||
# End previous span if exists
|
||||
if current_span is not None:
|
||||
spans.append(current_span)
|
||||
|
||||
# Start new span
|
||||
current_span = PredictionSpan(
|
||||
start_time=pred.time,
|
||||
end_time=pred.time,
|
||||
label=pred.label,
|
||||
avg_confidence=pred.confidence,
|
||||
candle_count=1
|
||||
)
|
||||
else:
|
||||
# Continue current span
|
||||
current_span.end_time = pred.time
|
||||
current_span.candle_count += 1
|
||||
# Update running average confidence
|
||||
total_confidence = current_span.avg_confidence * (current_span.candle_count - 1)
|
||||
current_span.avg_confidence = (total_confidence + pred.confidence) / current_span.candle_count
|
||||
|
||||
# Don't forget last span
|
||||
if current_span is not None:
|
||||
spans.append(current_span)
|
||||
|
||||
logger.info(f"Grouped {len(predictions)} predictions into {len(spans)} spans")
|
||||
|
||||
return spans
|
||||
|
||||
|
||||
# --- Prediction Endpoints ---
|
||||
|
||||
@app.post("/predict", response_model=PredictResponse)
|
||||
async def predict(request: PredictRequest):
|
||||
"""
|
||||
Predict candlestick patterns for provided candles.
|
||||
|
||||
Accepts OHLCV candles, runs preprocessing, and returns predictions.
|
||||
"""
|
||||
if state.model is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
if not request.candles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Candles array cannot be empty"
|
||||
)
|
||||
|
||||
if state.pipeline_config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Pipeline configuration not loaded"
|
||||
)
|
||||
|
||||
logger.info(f"Predict request: {request.pair} {request.timeframe}, {len(request.candles)} candles")
|
||||
|
||||
try:
|
||||
# Convert candles to list of dicts
|
||||
candles_data = [candle.model_dump() for candle in request.candles]
|
||||
|
||||
# Preprocess candles (feature engineering)
|
||||
df_preprocessed = preprocess_candles(candles_data, state.pipeline_config)
|
||||
|
||||
# Keep times for results mapping
|
||||
times = df_preprocessed['time'].values
|
||||
|
||||
# Extract feature columns (exclude 'time')
|
||||
X = extract_feature_columns(df_preprocessed)
|
||||
|
||||
# Get predictions and probabilities
|
||||
if hasattr(state.model, 'predict_proba'):
|
||||
y_pred = state.model.predict(X)
|
||||
y_proba = state.model.predict_proba(X)
|
||||
|
||||
# Get confidence as max probability
|
||||
confidences = np.max(y_proba, axis=1)
|
||||
else:
|
||||
# Fallback for models without predict_proba
|
||||
y_pred = state.model.predict(X)
|
||||
confidences = np.ones(len(y_pred)) # Default confidence of 1.0
|
||||
|
||||
# Get label names (handle both string and int predictions)
|
||||
if state.label_encoder is not None:
|
||||
# Model predicts integers, map to labels
|
||||
labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
|
||||
else:
|
||||
# Model predicts strings directly
|
||||
labels = [str(pred) for pred in y_pred]
|
||||
|
||||
# Build per-candle predictions
|
||||
predictions = [
|
||||
PredictionResult(
|
||||
time=int(time),
|
||||
label=label,
|
||||
confidence=float(conf)
|
||||
)
|
||||
for time, label, conf in zip(times, labels, confidences)
|
||||
]
|
||||
|
||||
# Group into spans
|
||||
spans = group_prediction_spans(predictions)
|
||||
|
||||
# Build model info for response
|
||||
model_info = ModelInfo(
|
||||
model_name=state.model_info["model_name"],
|
||||
model_version=state.model_info.get("model_version"),
|
||||
model_type=state.model_info["model_type"],
|
||||
trained_at=state.model_info.get("trained_at"),
|
||||
dataset_version=state.model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Prediction complete: {len(predictions)} candles, "
|
||||
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
|
||||
)
|
||||
|
||||
return PredictResponse(
|
||||
predictions=predictions,
|
||||
spans=spans,
|
||||
model_info=model_info
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Prediction validation error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Prediction failed: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Prediction failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/predict/batch", response_model=PredictResponse)
|
||||
async def predict_batch(request: BatchPredictRequest):
|
||||
"""
|
||||
Batch prediction for a date range.
|
||||
|
||||
Loads data from storage and returns predictions for the full range.
|
||||
"""
|
||||
if state.model is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
if state.pipeline_config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Pipeline configuration not loaded"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch predict: {request.pair} {request.timeframe} "
|
||||
f"from {request.start_date} to {request.end_date}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load OHLCV data from raw data path
|
||||
raw_path = Path(state.pipeline_config.data.raw_path)
|
||||
|
||||
if not raw_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No data found for {request.pair} {request.timeframe}"
|
||||
)
|
||||
|
||||
# Load data
|
||||
df_raw = pd.read_csv(raw_path)
|
||||
|
||||
# Filter by date range if time column exists
|
||||
if 'time' in df_raw.columns:
|
||||
# Parse dates to timestamps
|
||||
start_ts = int(pd.Timestamp(request.start_date).timestamp())
|
||||
end_ts = int(pd.Timestamp(request.end_date).timestamp())
|
||||
|
||||
# Filter data
|
||||
df_filtered = df_raw[
|
||||
(df_raw['time'] >= start_ts) & (df_raw['time'] <= end_ts)
|
||||
].copy()
|
||||
|
||||
if len(df_filtered) == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No data found in date range {request.start_date} to {request.end_date}"
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
|
||||
else:
|
||||
df_filtered = df_raw
|
||||
logger.warning("No 'time' column found, using all data")
|
||||
|
||||
# Get batch size from config
|
||||
batch_size = state.pipeline_config.stages.inference.batch_size
|
||||
|
||||
# Process in batches
|
||||
all_predictions = []
|
||||
all_spans = []
|
||||
|
||||
for i in range(0, len(df_filtered), batch_size):
|
||||
batch_df = df_filtered.iloc[i:i + batch_size]
|
||||
|
||||
logger.info(f"Processing batch {i // batch_size + 1}: rows {i} to {i + len(batch_df)}")
|
||||
|
||||
# Convert batch to candles format
|
||||
batch_candles = batch_df.to_dict('records')
|
||||
|
||||
# Preprocess
|
||||
df_preprocessed = preprocess_candles(batch_candles, state.pipeline_config)
|
||||
|
||||
# Keep times
|
||||
times = df_preprocessed['time'].values
|
||||
|
||||
# Extract features
|
||||
X = extract_feature_columns(df_preprocessed)
|
||||
|
||||
# Predict
|
||||
if hasattr(state.model, 'predict_proba'):
|
||||
y_pred = state.model.predict(X)
|
||||
y_proba = state.model.predict_proba(X)
|
||||
confidences = np.max(y_proba, axis=1)
|
||||
else:
|
||||
y_pred = state.model.predict(X)
|
||||
confidences = np.ones(len(y_pred))
|
||||
|
||||
# Get labels
|
||||
if state.label_encoder is not None:
|
||||
labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
|
||||
else:
|
||||
labels = [str(pred) for pred in y_pred]
|
||||
|
||||
# Build predictions for this batch
|
||||
batch_predictions = [
|
||||
PredictionResult(
|
||||
time=int(time),
|
||||
label=label,
|
||||
confidence=float(conf)
|
||||
)
|
||||
for time, label, conf in zip(times, labels, confidences)
|
||||
]
|
||||
|
||||
all_predictions.extend(batch_predictions)
|
||||
|
||||
# Group all predictions into spans
|
||||
all_spans = group_prediction_spans(all_predictions)
|
||||
|
||||
# Build model info
|
||||
model_info = ModelInfo(
|
||||
model_name=state.model_info["model_name"],
|
||||
model_version=state.model_info.get("model_version"),
|
||||
model_type=state.model_info["model_type"],
|
||||
trained_at=state.model_info.get("trained_at"),
|
||||
dataset_version=state.model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch prediction complete: {len(all_predictions)} candles, "
|
||||
f"{len(all_spans)} spans"
|
||||
)
|
||||
|
||||
return PredictResponse(
|
||||
predictions=all_predictions,
|
||||
spans=all_spans,
|
||||
model_info=model_info
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Batch prediction failed: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Batch prediction failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
185
services/ml/app/preprocessing.py
Normal file
185
services/ml/app/preprocessing.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""
|
||||
Preprocessing module for inference.
|
||||
|
||||
Replicates feature engineering pipeline to ensure preprocessing parity
|
||||
between training and inference.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from app.config import PipelineConfig
|
||||
from features.talib_features import compute_talib_indicators
|
||||
from features.candle_features import compute_candle_features, validate_candle_data
|
||||
from features.custom_loader import load_custom_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def preprocess_candles(
|
||||
candles: List[dict],
|
||||
pipeline_config: PipelineConfig
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Preprocess candle data for inference.
|
||||
|
||||
Applies the same feature engineering steps as used during training:
|
||||
1. Convert to DataFrame
|
||||
2. Validate OHLC data
|
||||
3. Compute TA-Lib indicators (if enabled)
|
||||
4. Compute candle features (if enabled)
|
||||
5. Load custom features (if configured)
|
||||
6. Drop NaN rows from indicator warmup
|
||||
|
||||
Args:
|
||||
candles: List of candle dictionaries with time, open, high, low, close, volume
|
||||
pipeline_config: Pipeline configuration (must match training config)
|
||||
|
||||
Returns:
|
||||
Preprocessed DataFrame ready for model.predict()
|
||||
|
||||
Raises:
|
||||
ValueError: If data validation fails or too many rows dropped
|
||||
"""
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(candles)
|
||||
|
||||
# Ensure time column exists for tracking
|
||||
if 'time' not in df.columns:
|
||||
raise ValueError("Candles must include 'time' field")
|
||||
|
||||
original_rows = len(df)
|
||||
logger.info(f"Preprocessing {original_rows} candles")
|
||||
|
||||
# Validate OHLC data
|
||||
try:
|
||||
validate_candle_data(df)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Candle data validation failed: {e}")
|
||||
|
||||
# Get feature engineering config
|
||||
fe_config = pipeline_config.stages.feature_engineering
|
||||
|
||||
if not fe_config.enabled:
|
||||
logger.warning("Feature engineering disabled in config - returning raw OHLCV")
|
||||
return df
|
||||
|
||||
# Compute TA-Lib indicators
|
||||
if fe_config.talib_indicators:
|
||||
logger.info(f"Computing {len(fe_config.talib_indicators)} TA-Lib indicators")
|
||||
try:
|
||||
df = compute_talib_indicators(df, fe_config.talib_indicators)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute TA-Lib indicators: {e}")
|
||||
raise ValueError(f"Indicator computation failed: {e}")
|
||||
|
||||
# Compute candle features
|
||||
if fe_config.candle_features:
|
||||
logger.info("Computing candle features")
|
||||
try:
|
||||
df = compute_candle_features(df)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute candle features: {e}")
|
||||
raise ValueError(f"Candle feature computation failed: {e}")
|
||||
|
||||
# Load custom features
|
||||
if fe_config.custom_features:
|
||||
logger.info(f"Loading {len(fe_config.custom_features)} custom feature(s)")
|
||||
try:
|
||||
df = load_custom_features(df, fe_config.custom_features)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load custom features: {e}")
|
||||
raise ValueError(f"Custom feature loading failed: {e}")
|
||||
|
||||
# Handle NaN values from indicator warmup
|
||||
df_clean = df.dropna()
|
||||
|
||||
rows_dropped = original_rows - len(df_clean)
|
||||
|
||||
if rows_dropped > 0:
|
||||
logger.info(
|
||||
f"Dropped {rows_dropped} rows due to indicator warmup "
|
||||
f"({rows_dropped / original_rows * 100:.1f}%)"
|
||||
)
|
||||
|
||||
# Warn if too much data was lost
|
||||
if rows_dropped / original_rows > 0.5:
|
||||
raise ValueError(
|
||||
f"More than 50% of candles dropped due to indicator warmup "
|
||||
f"({rows_dropped}/{original_rows}). Provide more historical candles."
|
||||
)
|
||||
|
||||
logger.info(f"Preprocessing complete: {len(df_clean)} candles ready for prediction")
|
||||
|
||||
return df_clean
|
||||
|
||||
|
||||
def extract_feature_columns(
|
||||
df: pd.DataFrame,
|
||||
exclude_columns: List[str] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Extract only feature columns for model prediction.
|
||||
|
||||
Removes metadata columns like 'time' that should not be used as features.
|
||||
|
||||
Args:
|
||||
df: Preprocessed DataFrame
|
||||
exclude_columns: Columns to exclude (default: ['time'])
|
||||
|
||||
Returns:
|
||||
DataFrame with only feature columns
|
||||
"""
|
||||
if exclude_columns is None:
|
||||
exclude_columns = ['time']
|
||||
|
||||
feature_cols = [col for col in df.columns if col not in exclude_columns]
|
||||
|
||||
logger.info(f"Using {len(feature_cols)} feature columns for prediction")
|
||||
|
||||
return df[feature_cols]
|
||||
|
||||
|
||||
def validate_feature_parity(
|
||||
inference_features: List[str],
|
||||
training_features: List[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that inference features match training features.
|
||||
|
||||
Args:
|
||||
inference_features: Feature column names from inference preprocessing
|
||||
training_features: Feature column names from training
|
||||
|
||||
Returns:
|
||||
True if features match exactly
|
||||
|
||||
Raises:
|
||||
ValueError: If features don't match
|
||||
"""
|
||||
inference_set = set(inference_features)
|
||||
training_set = set(training_features)
|
||||
|
||||
missing = training_set - inference_set
|
||||
extra = inference_set - training_set
|
||||
|
||||
if missing or extra:
|
||||
error_msg = "Feature mismatch detected between training and inference:\n"
|
||||
|
||||
if missing:
|
||||
error_msg += f" Missing features: {sorted(missing)}\n"
|
||||
|
||||
if extra:
|
||||
error_msg += f" Extra features: {sorted(extra)}\n"
|
||||
|
||||
error_msg += "\nThis indicates preprocessing parity is broken. "
|
||||
error_msg += "Ensure the pipeline config used for inference matches training."
|
||||
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info("Feature parity validated: inference features match training")
|
||||
return True
|
||||
Loading…
Add table
Add a link
Reference in a new issue