750 lines
24 KiB
Python
750 lines
24 KiB
Python
"""
FastAPI inference service for candlestick pattern prediction.

Provides REST API endpoints for model serving, health checks, and prediction.
"""
|
|
|
|
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List

import joblib
import mlflow
import mlflow.sklearn
import mlflow.xgboost
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from app.config import load_config, PipelineConfig
from app.preprocessing import preprocess_candles, extract_feature_columns
|
|
|
|
# Root logging: INFO and above, one timestamped line per record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this service.
logger = logging.getLogger(__name__)
|
|
|
|
# Application object; title, description and version populate the OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
)
|
|
|
|
# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (e.g. "https://app.example.com,https://admin.example.com"); when it
# is unset or empty every origin is allowed ("*"), as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # With a wildcard origin plus allow_credentials=True, Starlette echoes the
    # caller's Origin header back, which lets any website make credentialed
    # requests. Credentials are therefore only enabled for an explicit list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Global state
class AppState:
    """Process-wide container for what the startup hook loads.

    Every attribute is a class-level default of ``None``; real values are
    assigned on the shared ``state`` instance created below.
    """

    # Parsed pipeline configuration (None until the startup hook loads it).
    pipeline_config: Optional[PipelineConfig] = None
    # Loaded estimator (None until a model loads successfully).
    model: Optional[Any] = None
    # Metadata dict describing the loaded model (name, version, type, ...).
    model_info: Optional[Dict[str, Any]] = None
    # Optional mapping from integer class id to label name; when None the
    # endpoints stringify the raw model outputs instead.
    label_encoder: Optional[Dict[int, str]] = None
    # Feature column names (not assigned by any code in this module's view).
    feature_columns: Optional[List[str]] = None


# Single shared instance read and written by the startup hook and the endpoints.
state = AppState()
|
|
|
|
|
|
# --- Pydantic Models ---
|
|
|
|
class CandleData(BaseModel):
    """Single candle data point."""

    # Candle timestamp; the field description states Unix seconds.
    time: int = Field(..., description="Unix timestamp in seconds")
    # OHLC prices. No cross-field checks (e.g. low <= high) are applied here.
    open: float
    high: float
    low: float
    close: float
    # Volume for the candle (units are not specified in this module).
    volume: float
|
|
|
|
|
|
class PredictRequest(BaseModel):
    """Request model for /predict endpoint."""

    # Free-form identifiers; not validated against a known list here.
    pair: str = Field(..., description="Trading pair (e.g., EURUSD)")
    timeframe: str = Field(..., description="Timeframe (e.g., 1H, 4H, 1D)")
    # min_length=1 makes validation reject an empty list before the handler runs.
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
|
|
|
|
|
|
class PredictionResult(BaseModel):
    """Single candle prediction."""

    # Timestamp of the candle this prediction belongs to.
    time: int
    # Predicted label; group_prediction_spans treats "O" as background.
    label: str
    # Model confidence, constrained to [0, 1] by the field validators.
    confidence: float = Field(..., ge=0.0, le=1.0)
|
|
|
|
|
|
class PredictionSpan(BaseModel):
    """Grouped prediction span."""

    # Timestamp of the first candle in the span.
    start_time: int
    # Timestamp of the last candle in the span (inclusive).
    end_time: int
    # Label shared by every candle in the span.
    label: str
    # Mean confidence over the span's candles, constrained to [0, 1].
    avg_confidence: float = Field(..., ge=0.0, le=1.0)
    # Number of consecutive candles grouped into the span.
    candle_count: int
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Model metadata."""

    # Registered model name (MLflow source) or file stem (local source).
    model_name: str
    # Registry version string, or the metadata "version" ("local" by default)
    # for file-based models.
    model_version: Optional[str] = None
    # Model family recorded at training time; loaders fall back to "unknown".
    model_type: str
    # ISO-format training timestamp, when known.
    trained_at: Optional[str] = None
    # Identifier of the training dataset, when recorded.
    dataset_version: Optional[str] = None
    # Whether the model was trained with feature engineering enabled.
    feature_engineering_enabled: bool
|
|
|
|
|
|
class PredictResponse(BaseModel):
    """Response model for /predict endpoint."""

    # One entry per preprocessed candle row.
    predictions: List[PredictionResult]
    # Consecutive non-"O" predictions grouped by label.
    spans: List[PredictionSpan]
    # Metadata of the model that produced the predictions.
    model_info: ModelInfo
|
|
|
|
|
|
class BatchPredictRequest(BaseModel):
    """Request model for /predict/batch endpoint."""

    # Identifiers used in logs and error messages.
    pair: str
    timeframe: str
    # Plain strings: they are parsed by the batch handler, not validated here.
    start_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
|
|
|
|
|
|
class LabelInfo(BaseModel):
    """Pattern label with display color."""

    # Label name as reported by the model metadata.
    name: str
    # Display colour string; /model/labels currently returns a fixed placeholder.
    color: str
|
|
|
|
|
|
class PerClassMetrics(BaseModel):
    """Per-class performance metrics."""

    # Class label the metrics refer to.
    label: str
    # Metric values as recorded for the training run.
    precision: float
    recall: float
    f1_score: float
    # Number of samples of this class behind the metrics.
    support: int
|
|
|
|
|
|
class ModelInfoResponse(BaseModel):
    """Extended model information with metrics."""

    # The first six fields mirror ModelInfo.
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
    # Label names the model can emit, as recorded in its metadata.
    labels: List[str]
    # One entry per class with recorded metrics (may be empty).
    per_class_metrics: List[PerClassMetrics]
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response."""

    # Overall status derived from the fields below.
    status: str = Field(..., description="healthy, degraded, or unhealthy")
    # True when a model object is held in the global state.
    model_loaded: bool
    # Dependency statuses as reported by /health.
    mlflow: str = Field(..., description="connected or disconnected")
    database: str = Field(..., description="connected or disconnected")
|
|
|
|
|
|
# --- Model Loading Functions ---
|
|
|
|
def _parse_labels_param(raw: str) -> List[str]:
    """Decode the ``labels`` run param into label names, returning [] if unparseable."""
    import ast

    try:
        parsed = json.loads(raw)
    except ValueError:
        # mlflow.log_param stores str(value), so a Python list arrives as
        # "['A', 'B']", which is not JSON but is a valid Python literal.
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            logger.warning(f"Could not parse 'labels' run param: {raw!r}")
            return []
    if not isinstance(parsed, (list, tuple)):
        logger.warning(f"Ignoring non-list 'labels' run param: {raw!r}")
        return []
    # Label names are exposed as strings by the API models.
    return [str(label) for label in parsed]


def _collect_per_class_metrics(
    metrics: Dict[str, float], params: Dict[str, str]
) -> List[Dict[str, Any]]:
    """Build per-class metric dicts from ``precision_<label>`` run metrics."""
    per_class: List[Dict[str, Any]] = []
    for key, precision in metrics.items():
        if not key.startswith("precision_"):
            continue
        label = key.removeprefix("precision_")
        raw_support = params.get(f"support_{label}", 0)
        try:
            # Params are strings; tolerate float-formatted counts such as "12.0".
            support = int(float(raw_support))
        except (TypeError, ValueError):
            logger.warning(f"Invalid support value for {label!r}: {raw_support!r}")
            support = 0
        per_class.append({
            "label": label,
            "precision": precision,
            "recall": metrics.get(f"recall_{label}", 0.0),
            "f1_score": metrics.get(f"f1_{label}", 0.0),
            "support": support,
        })
    return per_class


def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a model and its metadata from the MLflow model registry.

    The newest version at ``stage`` is resolved first and that exact version
    is loaded, so the returned metadata always describes the returned model.
    Loading by stage and looking the version up afterwards could pick up
    different versions if a promotion happened in between.

    Args:
        model_name: Name of the registered model
        stage: Model stage (Production, Staging, None)

    Returns:
        Tuple of (model, model_info). ``model`` is an MLflow pyfunc model;
        ``model_info`` holds name, version, type, training time, dataset
        version, feature-engineering flag, labels and per-class metrics.

    Raises:
        ValueError: If no version of the model exists at ``stage``.
        Exception: Any MLflow client/loading error is logged and re-raised.
    """
    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")

    try:
        client = mlflow.tracking.MlflowClient()
        model_versions = client.get_latest_versions(model_name, stages=[stage])

        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")

        model_version = model_versions[0]

        # Pin the exact version resolved above.
        model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_version.version}")

        run = client.get_run(model_version.run_id)
        params = run.data.params

        # MLflow stringifies params with str(), so a logged True arrives as
        # "True"; compare case-insensitively.
        fe_flag = str(params.get("feature_engineering", "true")).strip().lower() == "true"

        model_info = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            "trained_at": datetime.fromtimestamp(run.info.start_time / 1000).isoformat(),
            "dataset_version": params.get("dataset_version", None),
            "feature_engineering_enabled": fe_flag,
            "labels": _parse_labels_param(params.get("labels", "[]")),
            "per_class_metrics": _collect_per_class_metrics(run.data.metrics, params),
        }

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise
|
|
|
|
|
|
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a model (and optional metadata) from a local joblib file.

    The file may hold either the bare estimator or a dict of the form
    ``{"model": <estimator>, "metadata": {...}}``.

    Security: joblib.load unpickles the file, which can execute arbitrary
    code. Only point this at trusted model artifacts.

    Args:
        model_path: Path to .pkl model file

    Returns:
        Tuple of (model, model_info)

    Raises:
        FileNotFoundError: If model file doesn't exist
        ValueError: If the file is a dict without a usable "model" entry
        Exception: If model loading fails
    """
    path = Path(model_path)
    logger.info(f"Loading model from local file: {path}")

    if not path.exists():
        raise FileNotFoundError(f"Model file not found: {path}")

    try:
        model_data = joblib.load(path)

        if isinstance(model_data, dict):
            model = model_data.get("model")
            # "metadata" may be absent or explicitly None.
            metadata = model_data.get("metadata") or {}
        else:
            model = model_data
            metadata = {}

        # Without this check a dict lacking "model" would be reported as a
        # successful load while the service actually has no model.
        if model is None:
            raise ValueError(f"Model file {path} does not contain a 'model' entry")

        model_info = {
            "model_name": path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": metadata.get("labels", []),
            "per_class_metrics": metadata.get("per_class_metrics", []),
        }

        logger.info(f"Successfully loaded local model: {path.name}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise
|
|
|
|
|
|
# --- Startup Event ---
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """
    Startup hook: read the pipeline config, then load the configured model.

    Failures are logged rather than raised, so the service always starts. In
    that case ``state.pipeline_config`` and/or ``state.model`` remain unset
    and the endpoints report the problem.
    """
    logger.info("Starting inference service...")

    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(config_path)
    except Exception as exc:
        logger.error(f"Failed to load pipeline config: {exc}")
        return
    logger.info(f"Loaded pipeline config from {config_path}")

    inference_cfg = state.pipeline_config.stages.inference
    if not inference_cfg.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    try:
        source = inference_cfg.model_source
        if source == "mlflow":
            # The tracking URI is taken from the training stage's MLflow settings.
            mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
            loaded = load_model_from_mlflow(
                model_name=inference_cfg.mlflow_model_name,
                stage=inference_cfg.mlflow_model_stage,
            )
        elif source == "local":
            loaded = load_model_from_local(model_path=inference_cfg.local_model_path)
        else:
            logger.error(f"Unknown model_source: {source}")
            return
        state.model, state.model_info = loaded

        info = state.model_info
        logger.info("Model loaded successfully")
        logger.info(
            f"Model info: {info['model_name']} v{info['model_version']} ({info['model_type']})"
        )
    except Exception as exc:
        logger.error(f"Failed to load model: {exc}")
        logger.warning("Service will start without a model. Use /health to check status.")
|
|
|
|
|
|
# --- Health Check ---
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Report service health.

    ``model_loaded`` reflects whether a model object is held in the global
    state. The MLflow and database fields are placeholders: no connectivity
    probe is performed yet, so both are always reported as "connected" and
    the overall status is effectively "healthy" with a model, "unhealthy"
    without one.
    """
    has_model = state.model is not None

    # TODO: Actually check MLflow connection
    mlflow_state = "connected"
    # TODO: Actually check database connection
    database_state = "connected"

    dependencies_ok = mlflow_state == "connected" and database_state == "connected"
    if not has_model:
        overall = "unhealthy"
    elif dependencies_ok:
        overall = "healthy"
    else:
        overall = "degraded"

    return HealthResponse(
        status=overall,
        model_loaded=has_model,
        mlflow=mlflow_state,
        database=database_state,
    )
|
|
|
|
|
|
# --- Model Info Endpoints ---
|
|
|
|
@app.get("/model/info", response_model=ModelInfoResponse)
async def get_model_info():
    """
    Return metadata and per-class metrics for the loaded model.

    Raises:
        HTTPException: 503 when either the model or its metadata is missing.
    """
    info = state.model_info
    if state.model is None or info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Keys every loader always sets; a missing one is a programming error.
    required = {
        key: info[key]
        for key in ("model_name", "model_type", "feature_engineering_enabled")
    }
    # Keys that may legitimately be absent and default to None.
    optional = {
        key: info.get(key)
        for key in ("model_version", "trained_at", "dataset_version")
    }
    metrics = [PerClassMetrics(**entry) for entry in info.get("per_class_metrics", [])]

    return ModelInfoResponse(
        **required,
        **optional,
        labels=info.get("labels", []),
        per_class_metrics=metrics,
    )
|
|
|
|
|
|
@app.get("/model/labels", response_model=List[LabelInfo])
async def get_model_labels():
    """
    List the labels recorded in the model metadata, each with a display colour.

    Raises:
        HTTPException: 503 when no model is loaded.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # TODO: Load label colors from config or database
    # Until then every label gets the same neutral placeholder colour.
    info = state.model_info or {}
    return [LabelInfo(name=name, color="#888888") for name in info.get("labels", [])]
|
|
|
|
|
|
# --- Prediction Helper Functions ---
|
|
|
|
def _span_from_run(run: List[PredictionResult]) -> PredictionSpan:
    """Collapse a non-empty run of consecutive same-label predictions into one span."""
    return PredictionSpan(
        start_time=run[0].time,
        end_time=run[-1].time,
        label=run[0].label,
        # Plain mean of the run's confidences. Summing once avoids the drift of
        # re-deriving a total from a rounded running average, and building the
        # span here means its ge/le constraints validate the value returned.
        avg_confidence=sum(p.confidence for p in run) / len(run),
        candle_count=len(run),
    )


def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Group consecutive predictions with the same label into spans.

    "O" predictions are background: they never form spans and they end any
    span in progress. Each span's ``avg_confidence`` is the arithmetic mean
    of its candles' confidences.

    Args:
        predictions: Per-candle predictions. Spans are built from list order,
            not by sorting on ``time``, so callers pass them time-ordered.

    Returns:
        List of prediction spans in the order they occur.
    """
    if not predictions:
        return []

    spans: List[PredictionSpan] = []
    run: List[PredictionResult] = []  # current run of identical non-"O" labels

    for pred in predictions:
        # A background candle or a label change closes the current run.
        if run and (pred.label == "O" or pred.label != run[-1].label):
            spans.append(_span_from_run(run))
            run = []
        if pred.label != "O":
            run.append(pred)

    # Flush a run that reaches the end of the input.
    if run:
        spans.append(_span_from_run(run))

    logger.info(f"Grouped {len(predictions)} predictions into {len(spans)} spans")

    return spans
|
|
|
|
|
|
# --- Prediction Endpoints ---
|
|
|
|
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict candlestick patterns for the provided candles.

    Runs the configured preprocessing on the OHLCV candles, scores every
    resulting row with the loaded model, and returns per-candle predictions
    plus spans of consecutive identical non-"O" labels.

    Raises:
        HTTPException: 503 if no model is loaded; 400 for an empty candle list
            or a ValueError raised while processing; 500 if the pipeline
            config is missing or prediction fails unexpectedly.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Defensive: PredictRequest already enforces min_length=1 on candles.
    if not request.candles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Candles array cannot be empty"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(f"Predict request: {request.pair} {request.timeframe}, {len(request.candles)} candles")

    try:
        candles_data = [candle.model_dump() for candle in request.candles]

        # Feature engineering; the resulting rows (not the raw candles) are scored.
        df_preprocessed = preprocess_candles(candles_data, state.pipeline_config)
        times = df_preprocessed['time'].values
        X = extract_feature_columns(df_preprocessed)

        y_pred = state.model.predict(X)
        if hasattr(state.model, 'predict_proba'):
            confidences = np.max(state.model.predict_proba(X), axis=1)
        else:
            # MLflow pyfunc wrappers expose only predict(): report full confidence.
            confidences = np.ones(len(y_pred))

        # Flatten a single-column result (e.g. a DataFrame returned by a pyfunc
        # model) so iteration yields one prediction per row, not column names.
        y_pred = np.asarray(y_pred)
        if y_pred.ndim == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred.ravel()

        if state.label_encoder is not None:
            # Model predicts integer class ids; map them to label names.
            labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
        else:
            labels = [str(pred) for pred in y_pred]

        # zip() would silently truncate on a mismatch; fail loudly instead.
        if not len(times) == len(labels) == len(confidences):
            raise RuntimeError(
                f"Model output length mismatch: {len(times)} rows, "
                f"{len(labels)} predictions, {len(confidences)} confidences"
            )

        predictions = [
            PredictionResult(
                time=int(time),
                label=label,
                # Clamp float overshoot (e.g. 1.0000001) that would otherwise
                # fail the [0, 1] field validation.
                confidence=min(max(float(conf), 0.0), 1.0)
            )
            for time, label, conf in zip(times, labels, confidences)
        ]

        spans = group_prediction_spans(predictions)

        model_info = ModelInfo(
            model_name=state.model_info["model_name"],
            model_version=state.model_info.get("model_version"),
            model_type=state.model_info["model_type"],
            trained_at=state.model_info.get("trained_at"),
            dataset_version=state.model_info.get("dataset_version"),
            feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
        )

        logger.info(
            f"Prediction complete: {len(predictions)} candles, "
            f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
        )

        return PredictResponse(
            predictions=predictions,
            spans=spans,
            model_info=model_info
        )

    except ValueError as e:
        logger.error(f"Prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        ) from e
|
|
|
|
|
|
def _is_plain_date(value: str) -> bool:
    """Return True when ``value`` is a bare calendar date such as ``2024-01-31``."""
    try:
        datetime.strptime(value.strip(), "%Y-%m-%d")
    except ValueError:
        return False
    return True


@app.post("/predict/batch", response_model=PredictResponse)
async def predict_batch(request: BatchPredictRequest):
    """
    Batch prediction for a date range.

    Reads OHLCV rows from the configured raw CSV and keeps rows whose ``time``
    lies between ``start_date`` and ``end_date``. It preprocesses them in a
    single pass, then scores them in chunks of the configured batch size.

    A date-only ``end_date`` (YYYY-MM-DD) includes that entire day; a full
    timestamp is used as an inclusive upper bound as given.

    NOTE(review): the CSV at ``data.raw_path`` is used regardless of
    ``pair``/``timeframe``; those fields only appear in logs and messages.

    Raises:
        HTTPException: 503 if no model is loaded; 500 if the pipeline config
            is missing or prediction fails; 400 for unparseable dates or a
            start_date after end_date; 404 if the data file is missing or the
            range contains no rows.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(
        f"Batch predict: {request.pair} {request.timeframe} "
        f"from {request.start_date} to {request.end_date}"
    )

    try:
        raw_path = Path(state.pipeline_config.data.raw_path)

        if not raw_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No data found for {request.pair} {request.timeframe}"
            )

        df_raw = pd.read_csv(raw_path)

        if 'time' in df_raw.columns:
            try:
                start_ts = int(pd.Timestamp(request.start_date).timestamp())
                end_point = pd.Timestamp(request.end_date)
                # A bare date means "through the end of that day": use the
                # following midnight as an exclusive upper bound.
                whole_end_day = _is_plain_date(request.end_date)
                if whole_end_day:
                    end_point = end_point + pd.Timedelta(days=1)
                end_ts = int(end_point.timestamp())
            except ValueError as e:
                # Covers unparseable strings and NaT (e.g. an empty string).
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid date: {e}"
                ) from e

            if start_ts > end_ts:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="start_date must not be after end_date"
                )

            time_col = df_raw['time']
            before_end = time_col < end_ts if whole_end_day else time_col <= end_ts
            df_filtered = df_raw[(time_col >= start_ts) & before_end].copy()

            if df_filtered.empty:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"No data found in date range {request.start_date} to {request.end_date}"
                )

            logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
        else:
            df_filtered = df_raw
            logger.warning("No 'time' column found, using all data")

        # Preprocess the whole range in one pass and chunk only model scoring.
        # Preprocessing per batch would restart any look-back (rolling/lagged)
        # features at every batch boundary. Assumption: preprocess_candles may
        # use such features. TODO confirm.
        df_preprocessed = preprocess_candles(df_filtered.to_dict('records'), state.pipeline_config)
        times = df_preprocessed['time'].values
        X = extract_feature_columns(df_preprocessed)
        n_rows = len(X)

        batch_size = state.pipeline_config.stages.inference.batch_size
        if not batch_size or batch_size < 1:
            # Missing or invalid batch size: score everything in one call.
            batch_size = max(n_rows, 1)

        has_proba = hasattr(state.model, 'predict_proba')
        pred_chunks = []
        conf_chunks = []

        for start in range(0, n_rows, batch_size):
            stop = start + batch_size
            X_batch = X.iloc[start:stop] if hasattr(X, 'iloc') else X[start:stop]

            logger.info(f"Processing batch {start // batch_size + 1}: rows {start} to {start + len(X_batch)}")

            batch_pred = np.asarray(state.model.predict(X_batch))
            # Flatten single-column outputs (e.g. a pyfunc DataFrame).
            if batch_pred.ndim == 2 and batch_pred.shape[1] == 1:
                batch_pred = batch_pred.ravel()
            pred_chunks.append(batch_pred)

            if has_proba:
                conf_chunks.append(np.max(state.model.predict_proba(X_batch), axis=1))
            else:
                conf_chunks.append(np.ones(len(batch_pred)))

        y_pred = np.concatenate(pred_chunks) if pred_chunks else np.array([])
        confidences = np.concatenate(conf_chunks) if conf_chunks else np.array([])

        if state.label_encoder is not None:
            labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
        else:
            labels = [str(pred) for pred in y_pred]

        # zip() would silently truncate on a mismatch; fail loudly instead.
        if not len(times) == len(labels) == len(confidences):
            raise RuntimeError(
                f"Model output length mismatch: {len(times)} rows, "
                f"{len(labels)} predictions, {len(confidences)} confidences"
            )

        all_predictions = [
            PredictionResult(
                time=int(time),
                label=label,
                # Clamp float overshoot that would fail the [0, 1] validation.
                confidence=min(max(float(conf), 0.0), 1.0)
            )
            for time, label, conf in zip(times, labels, confidences)
        ]

        all_spans = group_prediction_spans(all_predictions)

        model_info = ModelInfo(
            model_name=state.model_info["model_name"],
            model_version=state.model_info.get("model_version"),
            model_type=state.model_info["model_type"],
            trained_at=state.model_info.get("trained_at"),
            dataset_version=state.model_info.get("dataset_version"),
            feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
        )

        logger.info(
            f"Batch prediction complete: {len(all_predictions)} candles, "
            f"{len(all_spans)} spans"
        )

        return PredictResponse(
            predictions=all_predictions,
            spans=all_spans,
            model_info=model_info
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch prediction failed: {str(e)}"
        ) from e
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn on port 8001.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8001)
|