Fix inference feature mismatch with training metadata

This commit is contained in:
Marko Djordjevic 2026-02-18 23:53:38 +01:00
parent 328476a581
commit 73c10a4156
3 changed files with 137 additions and 10 deletions

View file

@ -124,6 +124,7 @@ async def lifespan(app: FastAPI):
else: else:
logger.error(f"Unknown model_source: {inference_config.model_source}") logger.error(f"Unknown model_source: {inference_config.model_source}")
state.feature_columns = state.model_info.get("feature_columns")
logger.info("Model loaded successfully") logger.info("Model loaded successfully")
logger.info(f"Model info: {state.model_info['model_name']} " logger.info(f"Model info: {state.model_info['model_name']} "
f"v{state.model_info['model_version']} " f"v{state.model_info['model_version']} "
@ -181,6 +182,22 @@ class AppState:
state = AppState() state = AppState()
def _get_model_n_features(model: Any) -> Optional[int]:
"""
Extract expected feature count from a model (supports wrappers).
"""
if hasattr(model, "n_features_in_"):
try:
return int(model.n_features_in_)
except Exception:
return None
if hasattr(model, "model") and hasattr(model.model, "n_features_in_"):
try:
return int(model.model.n_features_in_)
except Exception:
return None
return None
# --- Pydantic Models --- # --- Pydantic Models ---
@ -467,7 +484,8 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
"dataset_version": metadata.get("dataset_version", None), "dataset_version": metadata.get("dataset_version", None),
"feature_engineering_enabled": metadata.get("feature_engineering_enabled", True), "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
"labels": labels, "labels": labels,
"per_class_metrics": metadata.get("per_class_metrics", []) "per_class_metrics": metadata.get("per_class_metrics", []),
"feature_columns": metadata.get("feature_columns", None),
} }
logger.info(f"Successfully loaded local model: {model_path.name}") logger.info(f"Successfully loaded local model: {model_path.name}")
@ -707,7 +725,25 @@ async def predict(request: PredictRequest):
candles_data = [candle.model_dump() for candle in candles_sorted] candles_data = [candle.model_dump() for candle in candles_sorted]
# Preprocess candles (feature engineering + windowing) # Preprocess candles (feature engineering + windowing)
X, window_times = preprocess_candles(candles_data, state.pipeline_config) training_feature_columns = (
current_model_info.get("feature_columns") if current_model_info else None
)
X, window_times = preprocess_candles(
candles_data,
state.pipeline_config,
training_feature_columns=training_feature_columns
)
expected_n_features = _get_model_n_features(current_model)
if expected_n_features is not None and X.shape[1] != expected_n_features:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Feature mismatch: model expects {expected_n_features} features, "
f"but preprocessing produced {X.shape[1]}. "
"Ensure the loaded model matches the inference preprocessing config."
),
)
# Get predictions and probabilities (using local reference, outside lock) # Get predictions and probabilities (using local reference, outside lock)
if hasattr(current_model, 'predict_proba'): if hasattr(current_model, 'predict_proba'):
@ -765,7 +801,7 @@ async def predict(request: PredictRequest):
logger.error(f"Prediction validation error: {e}") logger.error(f"Prediction validation error: {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Internal server error" detail=str(e)
) )
except Exception as e: except Exception as e:
logger.error(f"Prediction failed: {e}", exc_info=True) logger.error(f"Prediction failed: {e}", exc_info=True)
@ -883,7 +919,25 @@ async def predict_batch(request: BatchPredictRequest):
batch_candles = batch_df.to_dict('records') batch_candles = batch_df.to_dict('records')
# Preprocess (feature engineering + windowing) # Preprocess (feature engineering + windowing)
X, window_times = preprocess_candles(batch_candles, state.pipeline_config) training_feature_columns = (
current_model_info.get("feature_columns") if current_model_info else None
)
X, window_times = preprocess_candles(
batch_candles,
state.pipeline_config,
training_feature_columns=training_feature_columns
)
expected_n_features = _get_model_n_features(current_model)
if expected_n_features is not None and X.shape[1] != expected_n_features:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Feature mismatch: model expects {expected_n_features} features, "
f"but preprocessing produced {X.shape[1]}. "
"Ensure the loaded model matches the inference preprocessing config."
),
)
# Predict (using local reference, outside lock) # Predict (using local reference, outside lock)
if hasattr(current_model, 'predict_proba'): if hasattr(current_model, 'predict_proba'):
@ -1755,6 +1809,7 @@ async def model_load(request: ModelLoadRequest):
with _model_swap_lock: with _model_swap_lock:
state.model = new_model state.model = new_model
state.model_info = new_model_info state.model_info = new_model_info
state.feature_columns = new_model_info.get("feature_columns")
logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}") logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")

View file

@ -6,7 +6,8 @@ between training and inference.
""" """
import logging import logging
from typing import List, Tuple import re
from typing import List, Tuple, Optional
import pandas as pd import pandas as pd
import numpy as np import numpy as np
@ -34,9 +35,51 @@ TRAINING_FEATURE_ORDER = [
] ]
def _parse_training_feature_columns(
feature_columns: List[str]
) -> Tuple[int, List[str]]:
"""
Derive window size and per-candle feature order from flattened training columns.
Expected column format: "<feature>_<index>" (e.g., "open_0", "rsi_14_12").
"""
if not feature_columns:
raise ValueError("Training feature columns are empty")
feature_order: List[str] = []
max_idx = -1
idx_set = set()
for col in feature_columns:
match = re.match(r"^(.*)_([0-9]+)$", col)
if not match:
raise ValueError(f"Invalid training feature column format: {col}")
base = match.group(1)
idx = int(match.group(2))
if idx == 0:
feature_order.append(base)
if idx > max_idx:
max_idx = idx
idx_set.add(idx)
window_size = max_idx + 1
if window_size <= 0:
raise ValueError("Could not derive window size from training feature columns")
missing_idx = set(range(window_size)) - idx_set
if missing_idx:
raise ValueError(f"Missing window indices in training feature columns: {sorted(missing_idx)[:5]}")
if not feature_order:
raise ValueError("Could not derive per-candle feature order from training feature columns")
return window_size, feature_order
def preprocess_candles( def preprocess_candles(
candles: List[dict], candles: List[dict],
pipeline_config: PipelineConfig pipeline_config: PipelineConfig,
training_feature_columns: Optional[List[str]] = None
) -> Tuple[pd.DataFrame, np.ndarray]: ) -> Tuple[pd.DataFrame, np.ndarray]:
""" """
Preprocess candle data for inference. Preprocess candle data for inference.
@ -124,16 +167,24 @@ def preprocess_candles(
logger.info(f"Filling NaN values in {len(nan_cols)} columns (indicator warmup + missing data)") logger.info(f"Filling NaN values in {len(nan_cols)} columns (indicator warmup + missing data)")
df = df.fillna(0.0) df = df.fillna(0.0)
# Determine expected feature order and window size
if training_feature_columns:
window_size, feature_order = _parse_training_feature_columns(training_feature_columns)
logger.info(f"Using training feature columns: {len(feature_order)} features, window_size={window_size}")
else:
window_size = TRAINING_WINDOW_SIZE
feature_order = TRAINING_FEATURE_ORDER
# Ensure all expected per-candle features exist # Ensure all expected per-candle features exist
for col in TRAINING_FEATURE_ORDER: for col in feature_order:
if col not in df.columns: if col not in df.columns:
logger.warning(f"Missing expected feature column '{col}', filling with 0") logger.warning(f"Missing expected feature column '{col}', filling with 0")
df[col] = 0.0 df[col] = 0.0
logger.info(f"Preprocessing complete: {len(df)} candles with {len(TRAINING_FEATURE_ORDER)} features each") logger.info(f"Preprocessing complete: {len(df)} candles with {len(feature_order)} features each")
# Create sliding windows and flatten # Create sliding windows and flatten
X, window_times = create_sliding_windows(df, TRAINING_WINDOW_SIZE, TRAINING_FEATURE_ORDER) X, window_times = create_sliding_windows(df, window_size, feature_order)
return X, window_times return X, window_times

View file

@ -388,7 +388,28 @@ def train(
import joblib import joblib
output_model_path = Path(output_model_path) output_model_path = Path(output_model_path)
output_model_path.parent.mkdir(parents=True, exist_ok=True) output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)
labels = []
try:
if hasattr(model, "classes_"):
labels = [str(c) for c in model.classes_]
elif hasattr(model, "model") and hasattr(model.model, "classes_"):
labels = [str(c) for c in model.model.classes_]
except Exception:
labels = []
model_data = {
"model": model,
"metadata": {
"model_type": training_config.model_type,
"trained_at": datetime.utcnow().isoformat(),
"run_id": run_id,
"feature_columns": feature_cols,
"feature_engineering_enabled": config.stages.feature_engineering.enabled,
"labels": labels,
},
}
joblib.dump(model_data, output_model_path)
logger.info(f"Saved model to {output_model_path}") logger.info(f"Saved model to {output_model_path}")
# Update training run record in PostgreSQL # Update training run record in PostgreSQL