feat(ml): implement FastAPI inference service with model loading, preprocessing, and prediction endpoints

This commit is contained in:
Marko Djordjevic 2026-02-15 14:29:07 +01:00
parent f4c0f9a836
commit 3a83fd38e9
3 changed files with 945 additions and 10 deletions

View file

@ -51,16 +51,16 @@
## 6. Inference Service (FastAPI)
- [ ] 6.1 Create `services/ml/app/main.py` — FastAPI app with CORS, startup event to load model
- [ ] 6.2 Implement model loading — from MLflow registry (by name + stage) or from local .pkl file via joblib
- [ ] 6.3 Implement preprocessing parity — load pipeline config from MLflow artifact, apply same feature engineering as training
- [ ] 6.4 Create `POST /predict` endpoint — accept candles array, run preprocessing, predict, return per-candle labels + confidence + spans + model_info
- [ ] 6.5 Implement prediction span grouping — group consecutive same-label non-"O" predictions into spans with avg_confidence
- [ ] 6.6 Create `POST /predict/batch` endpoint — accept pair/timeframe/date range, load data, process in batch_size chunks, return predictions
- [ ] 6.7 Create `GET /model/info` endpoint — return model metadata, per-class metrics from MLflow
- [ ] 6.8 Create `GET /model/labels` endpoint — return label names and colors
- [ ] 6.9 Create `GET /health` endpoint — check model loaded status, MLflow connection, PostgreSQL connection
- [ ] 6.10 Add Pydantic request/response models for all endpoints (PredictRequest, PredictResponse, BatchPredictRequest, ModelInfoResponse)
- [x] 6.1 Create `services/ml/app/main.py` — FastAPI app with CORS, startup event to load model
- [x] 6.2 Implement model loading — from MLflow registry (by name + stage) or from local .pkl file via joblib
- [x] 6.3 Implement preprocessing parity — load pipeline config from MLflow artifact, apply same feature engineering as training
- [x] 6.4 Create `POST /predict` endpoint — accept candles array, run preprocessing, predict, return per-candle labels + confidence + spans + model_info
- [x] 6.5 Implement prediction span grouping — group consecutive same-label non-"O" predictions into spans with avg_confidence
- [x] 6.6 Create `POST /predict/batch` endpoint — accept pair/timeframe/date range, load data, process in batch_size chunks, return predictions
- [x] 6.7 Create `GET /model/info` endpoint — return model metadata, per-class metrics from MLflow
- [x] 6.8 Create `GET /model/labels` endpoint — return label names and colors
- [x] 6.9 Create `GET /health` endpoint — check model loaded status, MLflow connection, PostgreSQL connection
- [x] 6.10 Add Pydantic request/response models for all endpoints (PredictRequest, PredictResponse, BatchPredictRequest, ModelInfoResponse)
## 7. Next.js API Proxy Routes

750
services/ml/app/main.py Normal file
View file

@ -0,0 +1,750 @@
"""
FastAPI inference service for candlestick pattern prediction.
Provides REST API endpoints for model serving, health checks, and prediction.
"""
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List

import joblib
import mlflow
import mlflow.sklearn
import mlflow.xgboost
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from app.config import load_config, PipelineConfig
from app.preprocessing import preprocess_candles, extract_feature_columns
# Logging setup: INFO and above, each record stamped with time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# FastAPI application object; the route handlers below register themselves on it.
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0"
)
# CORS middleware.
# A wildcard origin must not be combined with credentialed requests: FastAPI's
# CORS documentation requires explicit origins/methods/headers whenever
# allow_credentials is True, and Starlette would otherwise echo back any
# requesting origin for cookie-bearing requests. Nothing in this service reads
# cookies or auth headers, so credentials stay off while origins are open.
# Re-enable allow_credentials only together with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global state
class AppState:
    """Application state container."""
    # All attributes are class-level defaults; the single instance `state`
    # below is shared by every request handler in this module.

    # Loaded model object; stays None until startup_event loads one.
    model: Optional[Any] = None
    # Metadata dict returned alongside the model by the load_model_from_* helpers.
    model_info: Optional[Dict[str, Any]] = None
    # Pipeline configuration read from config/pipeline.yaml at startup.
    pipeline_config: Optional[PipelineConfig] = None
    # NOTE(review): never assigned or read in the visible code — confirm whether
    # this is still planned (e.g. for a feature-parity check).
    feature_columns: Optional[List[str]] = None
    # Maps integer class ids to label names. The prediction endpoints read it,
    # but nothing visible here assigns it, so they fall back to str(prediction).
    label_encoder: Optional[Dict[int, str]] = None
state = AppState()
# --- Pydantic Models ---
class CandleData(BaseModel):
    """Single candle data point."""
    # Candle timestamp; the field description fixes the unit as Unix seconds.
    time: int = Field(..., description="Unix timestamp in seconds")
    # OHLCV values. /predict serializes each candle with model_dump() and passes
    # the resulting dicts to preprocess_candles.
    open: float
    high: float
    low: float
    close: float
    volume: float
class PredictRequest(BaseModel):
    """Request model for /predict endpoint."""
    # pair and timeframe are only logged by the /predict handler; they do not
    # influence preprocessing or prediction in the visible code.
    pair: str = Field(..., description="Trading pair (e.g., EURUSD)")
    timeframe: str = Field(..., description="Timeframe (e.g., 1H, 4H, 1D)")
    # min_length=1 rejects an empty candle list at validation time.
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
class PredictionResult(BaseModel):
    """Single candle prediction."""
    # Timestamp copied from the preprocessed candle this prediction belongs to.
    time: int
    # Predicted label; "O" is treated as background by group_prediction_spans.
    label: str
    # Constrained to the closed interval [0.0, 1.0].
    confidence: float = Field(..., ge=0.0, le=1.0)
class PredictionSpan(BaseModel):
    """Grouped prediction span."""
    # group_prediction_spans creates a span from its first candle and then
    # updates end_time, avg_confidence and candle_count in place as the run grows.
    start_time: int
    end_time: int
    label: str
    # Running mean of the member candles' confidences, constrained to [0.0, 1.0]
    # when the span is constructed.
    avg_confidence: float = Field(..., ge=0.0, le=1.0)
    candle_count: int
class ModelInfo(BaseModel):
    """Model metadata."""
    # Populated from state.model_info by the prediction endpoints.
    # NOTE(review): field names starting with `model_` can collide with
    # Pydantic v2's protected `model_` namespace and emit a warning at class
    # creation on some versions — confirm against the pinned Pydantic version.
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
class PredictResponse(BaseModel):
    """Response model for /predict endpoint."""
    # Also used as the response model of /predict/batch.
    # One entry per candle that survived preprocessing.
    predictions: List[PredictionResult]
    # Consecutive same-label, non-"O" predictions grouped by group_prediction_spans.
    spans: List[PredictionSpan]
    model_info: ModelInfo
class BatchPredictRequest(BaseModel):
    """Request model for /predict/batch endpoint."""
    # pair and timeframe appear only in log and error messages of the batch
    # handler; the data file is chosen from the pipeline config.
    pair: str
    timeframe: str
    # Dates are plain strings here (no format validation at the model level);
    # the batch handler parses them with pd.Timestamp.
    start_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
class LabelInfo(BaseModel):
    """Pattern label with display color."""
    name: str
    # GET /model/labels currently returns the placeholder "#888888" for every label.
    color: str
class PerClassMetrics(BaseModel):
    """Per-class performance metrics."""
    # GET /model/info builds instances with PerClassMetrics(**metric), so the
    # dicts stored in model_info["per_class_metrics"] must use these exact keys.
    label: str
    precision: float
    recall: float
    f1_score: float
    support: int
class ModelInfoResponse(BaseModel):
    """Extended model information with metrics."""
    # Response body of GET /model/info: the ModelInfo fields plus labels and
    # per-class metrics, all read from state.model_info.
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
    labels: List[str]
    per_class_metrics: List[PerClassMetrics]
class HealthResponse(BaseModel):
    """Health check response."""
    # Overall verdict computed by health_check from the three fields below.
    status: str = Field(..., description="healthy, degraded, or unhealthy")
    # True when state.model is not None.
    model_loaded: bool
    mlflow: str = Field(..., description="connected or disconnected")
    database: str = Field(..., description="connected or disconnected")
# --- Model Loading Functions ---
def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load a registered model and its training-run metadata from MLflow.

    The model is loaded through the generic pyfunc flavor from
    ``models:/<model_name>/<stage>``. Metadata is taken from the run behind
    the latest registry version at that stage.

    Args:
        model_name: Name of the registered model
        stage: Model stage (Production, Staging, None)
    Returns:
        Tuple of (model, model_info). model_info carries model_name,
        model_version, model_type, trained_at, dataset_version,
        feature_engineering_enabled, labels and per_class_metrics.
    Raises:
        ValueError: If the registry has no version at the requested stage
        Exception: Any other MLflow or parsing failure (logged, then re-raised)
    """
    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")
    try:
        # Load model from registry.
        # NOTE(review): the prediction endpoints probe the model with
        # hasattr(model, 'predict_proba'); a pyfunc wrapper presumably lacks
        # it, so confidences would fall back to 1.0 — confirm.
        model_uri = f"models:/{model_name}/{stage}"
        model = mlflow.pyfunc.load_model(model_uri)

        # Get model version details
        client = mlflow.tracking.MlflowClient()
        model_versions = client.get_latest_versions(model_name, stages=[stage])
        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")
        model_version = model_versions[0]

        # Get run metadata
        run = client.get_run(model_version.run_id)
        params = run.data.params
        metrics = run.data.metrics

        # Extract model info. start_time is divided by 1000 (treated as
        # milliseconds); the resulting ISO string is naive local time.
        model_info = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            "trained_at": datetime.fromtimestamp(run.info.start_time / 1000).isoformat(),
            "dataset_version": params.get("dataset_version", None),
            "feature_engineering_enabled": params.get("feature_engineering", "true") == "true",
            "labels": json.loads(params.get("labels", "[]")),
            "per_class_metrics": []
        }

        # Extract per-class metrics: every "precision_<label>" metric defines a
        # label. Only the leading prefix is stripped so a label that itself
        # contains "precision_" is kept intact.
        prefix = "precision_"
        for key, value in metrics.items():
            if not key.startswith(prefix):
                continue
            label = key.removeprefix(prefix)
            model_info["per_class_metrics"].append({
                "label": label,
                "precision": value,
                "recall": metrics.get(f"recall_{label}", 0.0),
                "f1_score": metrics.get(f"f1_{label}", 0.0),
                # Support is read from run params (stored as strings), not metrics.
                "support": int(params.get(f"support_{label}", 0))
            })

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info
    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from local file using joblib.

    The file may contain either a bare model object or a dict with a
    "model" entry and an optional "metadata" dict.

    Args:
        model_path: Path to .pkl model file
    Returns:
        Tuple of (model, model_info)
    Raises:
        FileNotFoundError: If model file doesn't exist
        ValueError: If the file holds a dict without a usable "model" entry
        Exception: If model loading fails
    """
    path = Path(model_path)
    logger.info(f"Loading model from local file: {path}")
    if not path.exists():
        raise FileNotFoundError(f"Model file not found: {path}")
    try:
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code. Only point local_model_path at trusted artifacts.
        model_data = joblib.load(path)

        # Extract model and metadata
        if isinstance(model_data, dict):
            model = model_data.get("model")
            metadata = model_data.get("metadata", {})
        else:
            model = model_data
            metadata = {}

        # A dict payload without a model would otherwise be reported as a
        # successful load while leaving the service without a model.
        if model is None:
            raise ValueError(f"No model object found in {path}")

        # Build model info
        model_info = {
            "model_name": path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": metadata.get("labels", []),
            "per_class_metrics": metadata.get("per_class_metrics", [])
        }
        logger.info(f"Successfully loaded local model: {path.name}")
        return model, model_info
    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise
# --- Startup Event ---
@app.on_event("startup")
async def startup_event():
    """
    Populate the shared application state when the service starts.

    Reads config/pipeline.yaml, then loads the model from the configured
    source ("mlflow" or "local"). Every failure is logged and swallowed, so
    the service always starts; without a model the endpoints report that
    themselves.
    """
    logger.info("Starting inference service...")

    # Step 1: pipeline configuration. Nothing else is attempted without it.
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
        return
    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")
    except Exception as e:
        logger.error(f"Failed to load pipeline config: {e}")
        return

    # Step 2: honor the inference stage switch.
    inference_config = state.pipeline_config.stages.inference
    if not inference_config.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    # Step 3: load the model from whichever source the config names.
    try:
        if inference_config.model_source == "local":
            loaded = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        elif inference_config.model_source == "mlflow":
            # The registry lookup needs the tracking server used for training.
            tracking_uri = state.pipeline_config.stages.training.mlflow.tracking_uri
            mlflow.set_tracking_uri(tracking_uri)
            loaded = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage
            )
        else:
            logger.error(f"Unknown model_source: {inference_config.model_source}")
            return
        state.model, state.model_info = loaded
        logger.info("Model loaded successfully")
        info = state.model_info
        logger.info(f"Model info: {info['model_name']} "
                    f"v{info['model_version']} "
                    f"({info['model_type']})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        logger.warning("Service will start without a model. Use /health to check status.")
# --- Health Check ---
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Report service health.

    Only the model-loaded flag is a real measurement. The MLflow and
    database statuses are placeholders that always read "connected" until
    the TODOs below are implemented.
    """
    # TODO: Actually check MLflow connection
    mlflow_status = "connected"
    # TODO: Actually check database connection
    db_status = "connected"

    model_loaded = state.model is not None
    dependencies_ok = mlflow_status == "connected" and db_status == "connected"

    # No model -> unhealthy; model with every dependency up -> healthy;
    # model with a dependency down -> degraded.
    if not model_loaded:
        overall_status = "unhealthy"
    elif dependencies_ok:
        overall_status = "healthy"
    else:
        overall_status = "degraded"

    return HealthResponse(
        status=overall_status,
        model_loaded=model_loaded,
        mlflow=mlflow_status,
        database=db_status
    )
# --- Model Info Endpoints (label colors are still placeholders) ---
@app.get("/model/info", response_model=ModelInfoResponse)
async def get_model_info():
    """
    Return the loaded model's metadata together with its per-class metrics.

    Responds with HTTP 503 when no model (or no model metadata) is loaded.
    """
    info = state.model_info
    if state.model is None or info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Required keys are indexed directly; optional ones fall back via .get().
    response_fields = {
        "model_name": info["model_name"],
        "model_version": info.get("model_version"),
        "model_type": info["model_type"],
        "trained_at": info.get("trained_at"),
        "dataset_version": info.get("dataset_version"),
        "feature_engineering_enabled": info["feature_engineering_enabled"],
        "labels": info.get("labels", []),
    }

    # Each stored metric dict is expanded into a PerClassMetrics model.
    metrics = []
    for entry in info.get("per_class_metrics", []):
        metrics.append(PerClassMetrics(**entry))

    return ModelInfoResponse(per_class_metrics=metrics, **response_fields)
@app.get("/model/labels", response_model=List[LabelInfo])
async def get_model_labels():
    """
    List the labels known for the loaded model, each with a display color.

    Responds with HTTP 503 when no model is loaded. Every label currently
    gets the same placeholder color.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # TODO: Load label colors from config or database
    info = state.model_info
    label_names = info.get("labels", []) if info else []

    result = []
    for label_name in label_names:
        result.append(LabelInfo(name=label_name, color="#888888"))
    return result
# --- Prediction Helper Functions ---
def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Collapse runs of identical, non-background labels into spans.

    A span covers consecutive list entries that share one label. An "O"
    (background) prediction never starts or extends a span; it only closes
    the span that is currently open.

    Args:
        predictions: Per-candle predictions in the order they should be grouped
    Returns:
        Spans in order of appearance; an empty list for empty input
    """
    if not predictions:
        return []

    grouped = []
    active = None  # span still being extended, if any
    for item in predictions:
        is_background = item.label == "O"
        extends_active = (
            active is not None
            and not is_background
            and active.label == item.label
        )
        if extends_active:
            # Grow the open span and fold this confidence into its running mean.
            active.end_time = item.time
            active.candle_count += 1
            prior_total = active.avg_confidence * (active.candle_count - 1)
            active.avg_confidence = (prior_total + item.confidence) / active.candle_count
            continue

        # Background or a label change: close whatever span is open.
        if active is not None:
            grouped.append(active)
            active = None

        # A pattern label opens a fresh single-candle span.
        if not is_background:
            active = PredictionSpan(
                start_time=item.time,
                end_time=item.time,
                label=item.label,
                avg_confidence=item.confidence,
                candle_count=1
            )

    # Flush the span that was still open when the input ended.
    if active is not None:
        grouped.append(active)

    logger.info(f"Grouped {len(predictions)} predictions into {len(grouped)} spans")
    return grouped
# --- Prediction Endpoints ---
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict candlestick patterns for provided candles.

    Accepts OHLCV candles, runs the preprocessing pipeline, and returns one
    prediction per candle that survives preprocessing, the grouped spans,
    and the metadata of the model that produced them.

    Raises:
        HTTPException 503: No model (or model metadata) is loaded
        HTTPException 400: Empty candle list, or a ValueError raised during
            preprocessing/prediction (e.g. validation or warm-up failures)
        HTTPException 500: Missing pipeline config or any other failure
    """
    # model and model_info are set together at startup; guard both because
    # the response below indexes into model_info.
    if state.model is None or state.model_info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )
    if not request.candles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Candles array cannot be empty"
        )
    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )
    logger.info(f"Predict request: {request.pair} {request.timeframe}, {len(request.candles)} candles")
    try:
        # Convert candles to list of dicts
        candles_data = [candle.model_dump() for candle in request.candles]

        # Preprocess candles (feature engineering)
        df_preprocessed = preprocess_candles(candles_data, state.pipeline_config)

        # Keep times for results mapping
        times = df_preprocessed['time'].values

        # Extract feature columns (exclude 'time')
        X = extract_feature_columns(df_preprocessed)

        # Predicted classes, plus max class probability as confidence when
        # the model exposes predict_proba; otherwise a default of 1.0.
        y_pred = state.model.predict(X)
        if hasattr(state.model, 'predict_proba'):
            confidences = np.max(state.model.predict_proba(X), axis=1)
        else:
            confidences = np.ones(len(y_pred))

        # Get label names (handle both string and int predictions)
        if state.label_encoder is not None:
            # Model predicts integers, map to labels
            labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
        else:
            # Model predicts strings directly
            labels = [str(pred) for pred in y_pred]

        # Build per-candle predictions
        predictions = [
            PredictionResult(
                time=int(ts),
                label=label,
                confidence=float(conf)
            )
            for ts, label, conf in zip(times, labels, confidences)
        ]

        # Group into spans
        spans = group_prediction_spans(predictions)

        # Build model info for response
        model_info = ModelInfo(
            model_name=state.model_info["model_name"],
            model_version=state.model_info.get("model_version"),
            model_type=state.model_info["model_type"],
            trained_at=state.model_info.get("trained_at"),
            dataset_version=state.model_info.get("dataset_version"),
            feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
        )
        pattern_count = sum(1 for p in predictions if p.label != 'O')
        logger.info(
            f"Prediction complete: {len(predictions)} candles, "
            f"{len(spans)} spans, {pattern_count} patterns"
        )
        return PredictResponse(
            predictions=predictions,
            spans=spans,
            model_info=model_info
        )
    except ValueError as e:
        logger.error(f"Prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        ) from e
@app.post("/predict/batch", response_model=PredictResponse)
async def predict_batch(request: BatchPredictRequest):
    """
    Batch prediction for a date range.

    Reads the raw OHLCV CSV named by the pipeline config, keeps the rows
    whose 'time' lies within [start_date, end_date], preprocesses that range
    once, and runs the model over it in batch_size chunks.

    request.pair and request.timeframe are used only in log and error
    messages; they do not select the data file.

    Raises:
        HTTPException 503: No model (or model metadata) is loaded
        HTTPException 400: Unparseable or inverted date range, or a
            ValueError raised during preprocessing/prediction
        HTTPException 404: Raw data file missing or no rows in the range
        HTTPException 500: Missing pipeline config or any other failure
    """
    if state.model is None or state.model_info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )
    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )
    logger.info(
        f"Batch predict: {request.pair} {request.timeframe} "
        f"from {request.start_date} to {request.end_date}"
    )
    try:
        # Load OHLCV data from raw data path
        raw_path = Path(state.pipeline_config.data.raw_path)
        if not raw_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No data found for {request.pair} {request.timeframe}"
            )
        df_raw = pd.read_csv(raw_path)

        # Filter by date range if time column exists
        if 'time' in df_raw.columns:
            # Malformed dates are a client error, not a server fault.
            # NOTE(review): a date-only end_date resolves to 00:00 of that day,
            # so candles later on the end date are excluded — confirm intent.
            try:
                start_ts = int(pd.Timestamp(request.start_date).timestamp())
                end_ts = int(pd.Timestamp(request.end_date).timestamp())
            except (ValueError, TypeError) as e:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid date range: {e}"
                ) from e
            if start_ts > end_ts:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="start_date must not be after end_date"
                )
            in_range = (df_raw['time'] >= start_ts) & (df_raw['time'] <= end_ts)
            df_filtered = df_raw[in_range].copy()
            if df_filtered.empty:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"No data found in date range {request.start_date} to {request.end_date}"
                )
            logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
        else:
            df_filtered = df_raw
            logger.warning("No 'time' column found, using all data")

        # Preprocess the whole range once. Running feature engineering per
        # chunk would drop the indicator warm-up rows at the start of every
        # chunk (gaps at each boundary) and could trip the ">50% dropped"
        # check for small batch sizes.
        df_preprocessed = preprocess_candles(
            df_filtered.to_dict('records'), state.pipeline_config
        )
        times = df_preprocessed['time'].values
        X = extract_feature_columns(df_preprocessed)

        # Only the model calls are chunked; a non-positive batch size is
        # clamped to 1 so range() below always has a valid step.
        batch_size = max(1, int(state.pipeline_config.stages.inference.batch_size))
        has_proba = hasattr(state.model, 'predict_proba')

        all_predictions = []
        for start in range(0, len(X), batch_size):
            X_batch = X.iloc[start:start + batch_size]
            batch_times = times[start:start + batch_size]
            logger.info(
                f"Processing batch {start // batch_size + 1}: "
                f"rows {start} to {start + len(X_batch)}"
            )

            # Predict
            y_pred = state.model.predict(X_batch)
            if has_proba:
                confidences = np.max(state.model.predict_proba(X_batch), axis=1)
            else:
                confidences = np.ones(len(y_pred))

            # Get labels
            if state.label_encoder is not None:
                labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
            else:
                labels = [str(pred) for pred in y_pred]

            # Build predictions for this batch
            all_predictions.extend(
                PredictionResult(
                    time=int(ts),
                    label=label,
                    confidence=float(conf)
                )
                for ts, label, conf in zip(batch_times, labels, confidences)
            )

        # Group all predictions into spans (spans may cross chunk boundaries)
        all_spans = group_prediction_spans(all_predictions)

        # Build model info
        model_info = ModelInfo(
            model_name=state.model_info["model_name"],
            model_version=state.model_info.get("model_version"),
            model_type=state.model_info["model_type"],
            trained_at=state.model_info.get("trained_at"),
            dataset_version=state.model_info.get("dataset_version"),
            feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
        )
        logger.info(
            f"Batch prediction complete: {len(all_predictions)} candles, "
            f"{len(all_spans)} spans"
        )
        return PredictResponse(
            predictions=all_predictions,
            spans=all_spans,
            model_info=model_info
        )
    except HTTPException:
        raise
    except ValueError as e:
        # Same mapping as /predict: validation and preprocessing problems are
        # reported as client errors.
        logger.error(f"Batch prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch prediction failed: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Direct execution: serve this app with uvicorn on all interfaces, port 8001.
    from uvicorn import run as run_server
    run_server(app, host="0.0.0.0", port=8001)

View file

@ -0,0 +1,185 @@
"""
Preprocessing module for inference.
Replicates feature engineering pipeline to ensure preprocessing parity
between training and inference.
"""
import logging
from typing import List
import pandas as pd
import numpy as np
from app.config import PipelineConfig
from features.talib_features import compute_talib_indicators
from features.candle_features import compute_candle_features, validate_candle_data
from features.custom_loader import load_custom_features
logger = logging.getLogger(__name__)
def preprocess_candles(
    candles: List[dict],
    pipeline_config: PipelineConfig
) -> pd.DataFrame:
    """
    Preprocess candle data for inference.

    Applies the feature engineering steps selected by the pipeline config:
    1. Convert to DataFrame
    2. Validate OHLC data
    3. Compute TA-Lib indicators (if configured)
    4. Compute candle features (if configured)
    5. Load custom features (if configured)
    6. Drop rows containing NaN (e.g. from indicator warmup)

    When feature engineering is disabled, the validated raw DataFrame is
    returned unchanged and steps 3-6 are skipped.

    Args:
        candles: List of candle dictionaries with time, open, high, low, close, volume
        pipeline_config: Pipeline configuration (must match training config)
    Returns:
        Preprocessed DataFrame ready for model.predict(); still contains 'time'
    Raises:
        ValueError: If 'time' is missing, validation or a feature step fails,
            or more than half of the rows are dropped. The underlying error
            is chained as __cause__.
    """
    # Convert to DataFrame
    df = pd.DataFrame(candles)

    # Ensure time column exists for tracking
    if 'time' not in df.columns:
        raise ValueError("Candles must include 'time' field")
    original_rows = len(df)
    logger.info(f"Preprocessing {original_rows} candles")

    # Validate OHLC data
    try:
        validate_candle_data(df)
    except Exception as e:
        raise ValueError(f"Candle data validation failed: {e}") from e

    # Get feature engineering config
    fe_config = pipeline_config.stages.feature_engineering
    if not fe_config.enabled:
        logger.warning("Feature engineering disabled in config - returning raw OHLCV")
        return df

    # Every feature step re-raises as ValueError so callers can report it as
    # a client error; `from e` keeps the original traceback attached.

    # Compute TA-Lib indicators
    if fe_config.talib_indicators:
        logger.info(f"Computing {len(fe_config.talib_indicators)} TA-Lib indicators")
        try:
            df = compute_talib_indicators(df, fe_config.talib_indicators)
        except Exception as e:
            logger.error(f"Failed to compute TA-Lib indicators: {e}")
            raise ValueError(f"Indicator computation failed: {e}") from e

    # Compute candle features
    if fe_config.candle_features:
        logger.info("Computing candle features")
        try:
            df = compute_candle_features(df)
        except Exception as e:
            logger.error(f"Failed to compute candle features: {e}")
            raise ValueError(f"Candle feature computation failed: {e}") from e

    # Load custom features
    if fe_config.custom_features:
        logger.info(f"Loading {len(fe_config.custom_features)} custom feature(s)")
        try:
            df = load_custom_features(df, fe_config.custom_features)
        except Exception as e:
            logger.error(f"Failed to load custom features: {e}")
            raise ValueError(f"Custom feature loading failed: {e}") from e

    # Handle NaN values from indicator warmup (drops any row with a NaN in
    # any column)
    df_clean = df.dropna()
    rows_dropped = original_rows - len(df_clean)
    if rows_dropped > 0:
        logger.info(
            f"Dropped {rows_dropped} rows due to indicator warmup "
            f"({rows_dropped / original_rows * 100:.1f}%)"
        )
        # Reject the request if more than half of the data was lost
        if rows_dropped / original_rows > 0.5:
            raise ValueError(
                f"More than 50% of candles dropped due to indicator warmup "
                f"({rows_dropped}/{original_rows}). Provide more historical candles."
            )
    logger.info(f"Preprocessing complete: {len(df_clean)} candles ready for prediction")
    return df_clean
def extract_feature_columns(
    df: pd.DataFrame,
    exclude_columns: List[str] = None
) -> pd.DataFrame:
    """
    Select the columns that are passed to the model.

    Every column of df is kept, in its original order, unless its name is
    listed in exclude_columns.

    Args:
        df: Preprocessed DataFrame
        exclude_columns: Column names to leave out; None means ['time']
    Returns:
        DataFrame restricted to the remaining columns
    """
    excluded = ['time'] if exclude_columns is None else exclude_columns

    kept = []
    for name in df.columns:
        if name not in excluded:
            kept.append(name)

    logger.info(f"Using {len(kept)} feature columns for prediction")
    return df[kept]
def validate_feature_parity(
    inference_features: List[str],
    training_features: List[str]
) -> bool:
    """
    Check that inference and training use the same set of feature names.

    The two lists are compared as sets, so column order and duplicate
    names are not checked.

    Args:
        inference_features: Feature column names from inference preprocessing
        training_features: Feature column names from training
    Returns:
        True when both name sets are equal
    Raises:
        ValueError: If a name is missing from, or extra in, the inference set
    """
    seen = set(inference_features)
    expected = set(training_features)
    missing = expected - seen
    extra = seen - expected

    # Happy path first: nothing missing and nothing extra.
    if not missing and not extra:
        logger.info("Feature parity validated: inference features match training")
        return True

    # Assemble the report piece by piece; only non-empty sections are listed.
    parts = ["Feature mismatch detected between training and inference:\n"]
    if missing:
        parts.append(f" Missing features: {sorted(missing)}\n")
    if extra:
        parts.append(f" Extra features: {sorted(extra)}\n")
    parts.append("\nThis indicates preprocessing parity is broken. ")
    parts.append("Ensure the pipeline config used for inference matches training.")
    error_msg = "".join(parts)

    logger.error(error_msg)
    raise ValueError(error_msg)