fix(ml): add windowed feature flattening for inference parity
The model was trained on 94-candle sliding windows flattened to 2820 features (94 candles x 30 features). Inference was sending raw per-candle features (27 columns). Changes: - Rewrite preprocessing to return (X, window_times) tuple - Add sliding window creation with correct feature ordering - Fill missing columns (average, barCount) with 0 for feature parity - Fill NaN from indicator warmup with 0 instead of dropping rows - Always compute all indicators (including MFI) for feature parity - Update predict and batch predict endpoints for new signature
This commit is contained in:
parent
4c7b3f2676
commit
40d6d1739e
2 changed files with 111 additions and 63 deletions
|
|
@ -524,14 +524,8 @@ async def predict(request: PredictRequest):
|
|||
# Convert candles to list of dicts
|
||||
candles_data = [candle.model_dump() for candle in request.candles]
|
||||
|
||||
# Preprocess candles (feature engineering)
|
||||
df_preprocessed = preprocess_candles(candles_data, state.pipeline_config)
|
||||
|
||||
# Keep times for results mapping
|
||||
times = df_preprocessed['time'].values
|
||||
|
||||
# Extract feature columns (exclude 'time')
|
||||
X = extract_feature_columns(df_preprocessed)
|
||||
# Preprocess candles (feature engineering + windowing)
|
||||
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
||||
|
||||
# Get predictions and probabilities
|
||||
if hasattr(state.model, 'predict_proba'):
|
||||
|
|
@ -547,20 +541,18 @@ async def predict(request: PredictRequest):
|
|||
|
||||
# Get label names (handle both string and int predictions)
|
||||
if state.label_encoder is not None:
|
||||
# Model predicts integers, map to labels
|
||||
labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
|
||||
else:
|
||||
# Model predicts strings directly
|
||||
labels = [str(pred) for pred in y_pred]
|
||||
|
||||
# Build per-candle predictions
|
||||
# Build per-window predictions (each window maps to its last candle time)
|
||||
predictions = [
|
||||
PredictionResult(
|
||||
time=int(time),
|
||||
label=label,
|
||||
confidence=float(conf)
|
||||
)
|
||||
for time, label, conf in zip(times, labels, confidences)
|
||||
for time, label, conf in zip(window_times, labels, confidences)
|
||||
]
|
||||
|
||||
# Group into spans
|
||||
|
|
@ -577,7 +569,7 @@ async def predict(request: PredictRequest):
|
|||
)
|
||||
|
||||
logger.info(
|
||||
f"Prediction complete: {len(predictions)} candles, "
|
||||
f"Prediction complete: {len(predictions)} windows, "
|
||||
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
|
||||
)
|
||||
|
||||
|
|
@ -675,14 +667,8 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
# Convert batch to candles format
|
||||
batch_candles = batch_df.to_dict('records')
|
||||
|
||||
# Preprocess
|
||||
df_preprocessed = preprocess_candles(batch_candles, state.pipeline_config)
|
||||
|
||||
# Keep times
|
||||
times = df_preprocessed['time'].values
|
||||
|
||||
# Extract features
|
||||
X = extract_feature_columns(df_preprocessed)
|
||||
# Preprocess (feature engineering + windowing)
|
||||
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
||||
|
||||
# Predict
|
||||
if hasattr(state.model, 'predict_proba'):
|
||||
|
|
@ -706,7 +692,7 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
label=label,
|
||||
confidence=float(conf)
|
||||
)
|
||||
for time, label, conf in zip(times, labels, confidences)
|
||||
for time, label, conf in zip(window_times, labels, confidences)
|
||||
]
|
||||
|
||||
all_predictions.extend(batch_predictions)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ between training and inference.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
|
@ -18,28 +18,40 @@ from features.custom_loader import load_custom_features
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Window size used during training (number of candles per flattened sample)
|
||||
TRAINING_WINDOW_SIZE = 94
|
||||
|
||||
# Per-candle features expected by the model, in order
|
||||
TRAINING_FEATURE_ORDER = [
|
||||
'open', 'high', 'low', 'close', 'volume', 'average', 'barCount',
|
||||
'rsi_14', 'ema_20', 'ema_50',
|
||||
'macd_macd', 'macd_signal', 'macd_hist',
|
||||
'bbands_upper', 'bbands_middle', 'bbands_lower',
|
||||
'atr_14', 'adx_14', 'cci_14', 'mfi_14',
|
||||
'stoch_slowk', 'stoch_slowd',
|
||||
'body_size', 'body_direction', 'upper_wick', 'lower_wick',
|
||||
'wick_ratio', 'range', 'body_to_range', 'gap',
|
||||
]
|
||||
|
||||
|
||||
def preprocess_candles(
|
||||
candles: List[dict],
|
||||
pipeline_config: PipelineConfig
|
||||
) -> pd.DataFrame:
|
||||
) -> Tuple[pd.DataFrame, np.ndarray]:
|
||||
"""
|
||||
Preprocess candle data for inference.
|
||||
|
||||
Applies the same feature engineering steps as used during training:
|
||||
1. Convert to DataFrame
|
||||
2. Validate OHLC data
|
||||
3. Compute TA-Lib indicators (if enabled)
|
||||
4. Compute candle features (if enabled)
|
||||
5. Load custom features (if configured)
|
||||
6. Drop NaN rows from indicator warmup
|
||||
Applies the same feature engineering steps as used during training,
|
||||
then creates sliding windows and flattens them to match training format.
|
||||
|
||||
Args:
|
||||
candles: List of candle dictionaries with time, open, high, low, close, volume
|
||||
pipeline_config: Pipeline configuration (must match training config)
|
||||
|
||||
Returns:
|
||||
Preprocessed DataFrame ready for model.predict()
|
||||
Tuple of:
|
||||
- X: DataFrame with flattened windowed features (one row per window)
|
||||
- window_times: Array of time values, one per window (time of last candle)
|
||||
|
||||
Raises:
|
||||
ValueError: If data validation fails or too many rows dropped
|
||||
|
|
@ -60,33 +72,28 @@ def preprocess_candles(
|
|||
except Exception as e:
|
||||
raise ValueError(f"Candle data validation failed: {e}")
|
||||
|
||||
# Handle missing or all-null volume column - fill with 0 if absent/empty
|
||||
# Handle missing or all-null volume column
|
||||
if 'volume' not in df.columns or df['volume'].isna().all():
|
||||
logger.warning("Volume data missing from candles, filling with 0")
|
||||
df['volume'] = 0.0
|
||||
|
||||
# Add missing columns that were present in training data
|
||||
for col in ['average', 'barCount']:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0
|
||||
|
||||
# Get feature engineering config
|
||||
fe_config = pipeline_config.stages.feature_engineering
|
||||
|
||||
if not fe_config.enabled:
|
||||
logger.warning("Feature engineering disabled in config - returning raw OHLCV")
|
||||
return df
|
||||
raise ValueError("Feature engineering must be enabled for windowed inference")
|
||||
|
||||
# Compute TA-Lib indicators
|
||||
# Compute ALL TA-Lib indicators (including volume-dependent ones)
|
||||
if fe_config.talib_indicators:
|
||||
indicators = fe_config.talib_indicators
|
||||
# Skip volume-dependent indicators when volume data is unavailable
|
||||
volume_indicators = {'MFI', 'OBV', 'AD', 'ADOSC'}
|
||||
has_real_volume = df['volume'].sum() > 0
|
||||
if not has_real_volume:
|
||||
skipped = [i.name for i in indicators if i.name.upper() in volume_indicators]
|
||||
if skipped:
|
||||
logger.warning(f"Skipping volume-dependent indicators (no volume data): {skipped}")
|
||||
indicators = [i for i in indicators if i.name.upper() not in volume_indicators]
|
||||
|
||||
logger.info(f"Computing {len(indicators)} TA-Lib indicators")
|
||||
logger.info(f"Computing {len(fe_config.talib_indicators)} TA-Lib indicators")
|
||||
try:
|
||||
df = compute_talib_indicators(df, indicators)
|
||||
df = compute_talib_indicators(df, fe_config.talib_indicators)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute TA-Lib indicators: {e}")
|
||||
raise ValueError(f"Indicator computation failed: {e}")
|
||||
|
|
@ -109,27 +116,82 @@ def preprocess_candles(
|
|||
logger.error(f"Failed to load custom features: {e}")
|
||||
raise ValueError(f"Custom feature loading failed: {e}")
|
||||
|
||||
# Handle NaN values from indicator warmup
|
||||
df_clean = df.dropna()
|
||||
# Fill NaN values from indicator warmup and missing data with 0
|
||||
# (instead of dropping rows, since we need contiguous windows)
|
||||
nan_counts = df.isna().sum()
|
||||
nan_cols = nan_counts[nan_counts > 0]
|
||||
if not nan_cols.empty:
|
||||
logger.info(f"Filling NaN values in {len(nan_cols)} columns (indicator warmup + missing data)")
|
||||
df = df.fillna(0.0)
|
||||
|
||||
rows_dropped = original_rows - len(df_clean)
|
||||
# Ensure all expected per-candle features exist
|
||||
for col in TRAINING_FEATURE_ORDER:
|
||||
if col not in df.columns:
|
||||
logger.warning(f"Missing expected feature column '{col}', filling with 0")
|
||||
df[col] = 0.0
|
||||
|
||||
if rows_dropped > 0:
|
||||
logger.info(
|
||||
f"Dropped {rows_dropped} rows due to indicator warmup "
|
||||
f"({rows_dropped / original_rows * 100:.1f}%)"
|
||||
)
|
||||
logger.info(f"Preprocessing complete: {len(df)} candles with {len(TRAINING_FEATURE_ORDER)} features each")
|
||||
|
||||
# Create sliding windows and flatten
|
||||
X, window_times = create_sliding_windows(df, TRAINING_WINDOW_SIZE, TRAINING_FEATURE_ORDER)
|
||||
|
||||
return X, window_times
|
||||
|
||||
|
||||
def create_sliding_windows(
|
||||
df: pd.DataFrame,
|
||||
window_size: int,
|
||||
feature_cols: List[str]
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Create sliding windows from per-candle features and flatten.
|
||||
|
||||
Each window of `window_size` consecutive candles is flattened into a single
|
||||
row of features: [feat0_candle0, feat1_candle0, ..., featN_candle0,
|
||||
feat0_candle1, ..., featN_candleM]
|
||||
|
||||
Args:
|
||||
df: DataFrame with per-candle features and 'time' column
|
||||
window_size: Number of candles per window
|
||||
feature_cols: Ordered list of feature column names
|
||||
|
||||
# Warn if too much data was lost
|
||||
if rows_dropped / original_rows > 0.5:
|
||||
raise ValueError(
|
||||
f"More than 50% of candles dropped due to indicator warmup "
|
||||
f"({rows_dropped}/{original_rows}). Provide more historical candles."
|
||||
)
|
||||
Returns:
|
||||
Tuple of:
|
||||
- X: numpy array of shape (n_windows, window_size * n_features)
|
||||
- window_times: array of time values (last candle time in each window)
|
||||
"""
|
||||
n_candles = len(df)
|
||||
n_features = len(feature_cols)
|
||||
|
||||
logger.info(f"Preprocessing complete: {len(df_clean)} candles ready for prediction")
|
||||
if n_candles < window_size:
|
||||
raise ValueError(
|
||||
f"Not enough candles ({n_candles}) for window size {window_size}. "
|
||||
f"Need at least {window_size} candles."
|
||||
)
|
||||
|
||||
return df_clean
|
||||
# Extract feature matrix in correct column order
|
||||
feature_matrix = df[feature_cols].values # shape: (n_candles, n_features)
|
||||
times = df['time'].values
|
||||
|
||||
n_windows = n_candles - window_size + 1
|
||||
|
||||
# Create flattened windows using stride tricks for efficiency
|
||||
# Each window: candle features are interleaved as col_0, col_1, ..., col_N for each candle index
|
||||
X = np.zeros((n_windows, window_size * n_features), dtype=np.float64)
|
||||
window_times = np.zeros(n_windows, dtype=times.dtype)
|
||||
|
||||
for i in range(n_windows):
|
||||
window = feature_matrix[i:i + window_size] # shape: (window_size, n_features)
|
||||
# Flatten: row-major means [candle0_feat0, candle0_feat1, ..., candle1_feat0, ...]
|
||||
# But training used {col}_{candle_idx} ordering, which is column-first per candle
|
||||
# i.e., open_0, high_0, ..., gap_0, open_1, high_1, ..., gap_1, ...
|
||||
X[i] = window.flatten() # row-major: candle0_all_feats, candle1_all_feats, ...
|
||||
window_times[i] = times[i + window_size - 1] # last candle in window
|
||||
|
||||
logger.info(f"Created {n_windows} sliding windows of size {window_size} "
|
||||
f"({n_windows * n_features * window_size} total features)")
|
||||
|
||||
return X, window_times
|
||||
|
||||
|
||||
def extract_feature_columns(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue