feat(ml): implement FastAPI inference service with model loading, preprocessing, and prediction endpoints
This commit is contained in:
parent
f4c0f9a836
commit
3a83fd38e9
3 changed files with 945 additions and 10 deletions
185
services/ml/app/preprocessing.py
Normal file
185
services/ml/app/preprocessing.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""
|
||||
Preprocessing module for inference.
|
||||
|
||||
Replicates feature engineering pipeline to ensure preprocessing parity
|
||||
between training and inference.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from app.config import PipelineConfig
|
||||
from features.talib_features import compute_talib_indicators
|
||||
from features.candle_features import compute_candle_features, validate_candle_data
|
||||
from features.custom_loader import load_custom_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def preprocess_candles(
|
||||
candles: List[dict],
|
||||
pipeline_config: PipelineConfig
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Preprocess candle data for inference.
|
||||
|
||||
Applies the same feature engineering steps as used during training:
|
||||
1. Convert to DataFrame
|
||||
2. Validate OHLC data
|
||||
3. Compute TA-Lib indicators (if enabled)
|
||||
4. Compute candle features (if enabled)
|
||||
5. Load custom features (if configured)
|
||||
6. Drop NaN rows from indicator warmup
|
||||
|
||||
Args:
|
||||
candles: List of candle dictionaries with time, open, high, low, close, volume
|
||||
pipeline_config: Pipeline configuration (must match training config)
|
||||
|
||||
Returns:
|
||||
Preprocessed DataFrame ready for model.predict()
|
||||
|
||||
Raises:
|
||||
ValueError: If data validation fails or too many rows dropped
|
||||
"""
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(candles)
|
||||
|
||||
# Ensure time column exists for tracking
|
||||
if 'time' not in df.columns:
|
||||
raise ValueError("Candles must include 'time' field")
|
||||
|
||||
original_rows = len(df)
|
||||
logger.info(f"Preprocessing {original_rows} candles")
|
||||
|
||||
# Validate OHLC data
|
||||
try:
|
||||
validate_candle_data(df)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Candle data validation failed: {e}")
|
||||
|
||||
# Get feature engineering config
|
||||
fe_config = pipeline_config.stages.feature_engineering
|
||||
|
||||
if not fe_config.enabled:
|
||||
logger.warning("Feature engineering disabled in config - returning raw OHLCV")
|
||||
return df
|
||||
|
||||
# Compute TA-Lib indicators
|
||||
if fe_config.talib_indicators:
|
||||
logger.info(f"Computing {len(fe_config.talib_indicators)} TA-Lib indicators")
|
||||
try:
|
||||
df = compute_talib_indicators(df, fe_config.talib_indicators)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute TA-Lib indicators: {e}")
|
||||
raise ValueError(f"Indicator computation failed: {e}")
|
||||
|
||||
# Compute candle features
|
||||
if fe_config.candle_features:
|
||||
logger.info("Computing candle features")
|
||||
try:
|
||||
df = compute_candle_features(df)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute candle features: {e}")
|
||||
raise ValueError(f"Candle feature computation failed: {e}")
|
||||
|
||||
# Load custom features
|
||||
if fe_config.custom_features:
|
||||
logger.info(f"Loading {len(fe_config.custom_features)} custom feature(s)")
|
||||
try:
|
||||
df = load_custom_features(df, fe_config.custom_features)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load custom features: {e}")
|
||||
raise ValueError(f"Custom feature loading failed: {e}")
|
||||
|
||||
# Handle NaN values from indicator warmup
|
||||
df_clean = df.dropna()
|
||||
|
||||
rows_dropped = original_rows - len(df_clean)
|
||||
|
||||
if rows_dropped > 0:
|
||||
logger.info(
|
||||
f"Dropped {rows_dropped} rows due to indicator warmup "
|
||||
f"({rows_dropped / original_rows * 100:.1f}%)"
|
||||
)
|
||||
|
||||
# Warn if too much data was lost
|
||||
if rows_dropped / original_rows > 0.5:
|
||||
raise ValueError(
|
||||
f"More than 50% of candles dropped due to indicator warmup "
|
||||
f"({rows_dropped}/{original_rows}). Provide more historical candles."
|
||||
)
|
||||
|
||||
logger.info(f"Preprocessing complete: {len(df_clean)} candles ready for prediction")
|
||||
|
||||
return df_clean
|
||||
|
||||
|
||||
def extract_feature_columns(
|
||||
df: pd.DataFrame,
|
||||
exclude_columns: List[str] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Extract only feature columns for model prediction.
|
||||
|
||||
Removes metadata columns like 'time' that should not be used as features.
|
||||
|
||||
Args:
|
||||
df: Preprocessed DataFrame
|
||||
exclude_columns: Columns to exclude (default: ['time'])
|
||||
|
||||
Returns:
|
||||
DataFrame with only feature columns
|
||||
"""
|
||||
if exclude_columns is None:
|
||||
exclude_columns = ['time']
|
||||
|
||||
feature_cols = [col for col in df.columns if col not in exclude_columns]
|
||||
|
||||
logger.info(f"Using {len(feature_cols)} feature columns for prediction")
|
||||
|
||||
return df[feature_cols]
|
||||
|
||||
|
||||
def validate_feature_parity(
|
||||
inference_features: List[str],
|
||||
training_features: List[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that inference features match training features.
|
||||
|
||||
Args:
|
||||
inference_features: Feature column names from inference preprocessing
|
||||
training_features: Feature column names from training
|
||||
|
||||
Returns:
|
||||
True if features match exactly
|
||||
|
||||
Raises:
|
||||
ValueError: If features don't match
|
||||
"""
|
||||
inference_set = set(inference_features)
|
||||
training_set = set(training_features)
|
||||
|
||||
missing = training_set - inference_set
|
||||
extra = inference_set - training_set
|
||||
|
||||
if missing or extra:
|
||||
error_msg = "Feature mismatch detected between training and inference:\n"
|
||||
|
||||
if missing:
|
||||
error_msg += f" Missing features: {sorted(missing)}\n"
|
||||
|
||||
if extra:
|
||||
error_msg += f" Extra features: {sorted(extra)}\n"
|
||||
|
||||
error_msg += "\nThis indicates preprocessing parity is broken. "
|
||||
error_msg += "Ensure the pipeline config used for inference matches training."
|
||||
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info("Feature parity validated: inference features match training")
|
||||
return True
|
||||
Loading…
Add table
Add a link
Reference in a new issue