From 73c10a4156cb88e5cf84d23e664d076625b737d1 Mon Sep 17 00:00:00 2001 From: Marko Djordjevic Date: Wed, 18 Feb 2026 23:53:38 +0100 Subject: [PATCH] Fix inference feature mismatch with training metadata --- services/ml/app/main.py | 63 ++++++++++++++++++++++++++++++-- services/ml/app/preprocessing.py | 61 ++++++++++++++++++++++++++++--- services/ml/training/train.py | 23 +++++++++++- 3 files changed, 137 insertions(+), 10 deletions(-) diff --git a/services/ml/app/main.py b/services/ml/app/main.py index d428520..82d4c6b 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -124,6 +124,7 @@ async def lifespan(app: FastAPI): else: logger.error(f"Unknown model_source: {inference_config.model_source}") + state.feature_columns = state.model_info.get("feature_columns") logger.info("Model loaded successfully") logger.info(f"Model info: {state.model_info['model_name']} " f"v{state.model_info['model_version']} " @@ -181,6 +182,22 @@ class AppState: state = AppState() +def _get_model_n_features(model: Any) -> Optional[int]: + """ + Extract expected feature count from a model (supports wrappers). + """ + if hasattr(model, "n_features_in_"): + try: + return int(model.n_features_in_) + except Exception: + return None + if hasattr(model, "model") and hasattr(model.model, "n_features_in_"): + try: + return int(model.model.n_features_in_) + except Exception: + return None + return None + # --- Pydantic Models --- @@ -467,7 +484,8 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]: "dataset_version": metadata.get("dataset_version", None), "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True), "labels": labels, - "per_class_metrics": metadata.get("per_class_metrics", []) + "per_class_metrics": metadata.get("per_class_metrics", []), + "feature_columns": metadata.get("feature_columns", None), } logger.info(f"Successfully loaded local model: {model_path.name}") @@ -707,7 +725,25 @@ async def predict(request: PredictRequest): candles_data = [candle.model_dump() for candle in candles_sorted] # Preprocess candles (feature engineering + windowing) - X, window_times = preprocess_candles(candles_data, state.pipeline_config) + training_feature_columns = ( + current_model_info.get("feature_columns") if current_model_info else None + ) + X, window_times = preprocess_candles( + candles_data, + state.pipeline_config, + training_feature_columns=training_feature_columns + ) + + expected_n_features = _get_model_n_features(current_model) + if expected_n_features is not None and X.shape[1] != expected_n_features: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Feature mismatch: model expects {expected_n_features} features, " + f"but preprocessing produced {X.shape[1]}. " + "Ensure the loaded model matches the inference preprocessing config." + ), + ) # Get predictions and probabilities (using local reference, outside lock) if hasattr(current_model, 'predict_proba'): @@ -765,7 +801,7 @@ async def predict(request: PredictRequest): logger.error(f"Prediction validation error: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Internal server error" + detail=str(e) ) except Exception as e: logger.error(f"Prediction failed: {e}", exc_info=True) @@ -883,7 +919,25 @@ async def predict_batch(request: BatchPredictRequest): batch_candles = batch_df.to_dict('records') # Preprocess (feature engineering + windowing) - X, window_times = preprocess_candles(batch_candles, state.pipeline_config) + training_feature_columns = ( + current_model_info.get("feature_columns") if current_model_info else None + ) + X, window_times = preprocess_candles( + batch_candles, + state.pipeline_config, + training_feature_columns=training_feature_columns + ) + + expected_n_features = _get_model_n_features(current_model) + if expected_n_features is not None and X.shape[1] != expected_n_features: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Feature mismatch: model expects {expected_n_features} features, " + f"but preprocessing produced {X.shape[1]}. " + "Ensure the loaded model matches the inference preprocessing config." + ), + ) # Predict (using local reference, outside lock) if hasattr(current_model, 'predict_proba'): @@ -1755,6 +1809,7 @@ async def model_load(request: ModelLoadRequest): with _model_swap_lock: state.model = new_model state.model_info = new_model_info + state.feature_columns = new_model_info.get("feature_columns") logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}") diff --git a/services/ml/app/preprocessing.py b/services/ml/app/preprocessing.py index 2a61189..415aa97 100644 --- a/services/ml/app/preprocessing.py +++ b/services/ml/app/preprocessing.py @@ -6,7 +6,8 @@ between training and inference. """ import logging -from typing import List, Tuple +import re +from typing import List, Tuple, Optional import pandas as pd import numpy as np @@ -34,9 +35,51 @@ TRAINING_FEATURE_ORDER = [ ] +def _parse_training_feature_columns( + feature_columns: List[str] +) -> Tuple[int, List[str]]: + """ + Derive window size and per-candle feature order from flattened training columns. + + Expected column format: "_" (e.g., "open_0", "rsi_14_12"). + """ + if not feature_columns: + raise ValueError("Training feature columns are empty") + + feature_order: List[str] = [] + max_idx = -1 + idx_set = set() + + for col in feature_columns: + match = re.match(r"^(.*)_([0-9]+)$", col) + if not match: + raise ValueError(f"Invalid training feature column format: {col}") + base = match.group(1) + idx = int(match.group(2)) + if idx == 0: + feature_order.append(base) + if idx > max_idx: + max_idx = idx + idx_set.add(idx) + + window_size = max_idx + 1 + if window_size <= 0: + raise ValueError("Could not derive window size from training feature columns") + + missing_idx = set(range(window_size)) - idx_set + if missing_idx: + raise ValueError(f"Missing window indices in training feature columns: {sorted(missing_idx)[:5]}") + + if not feature_order: + raise ValueError("Could not derive per-candle feature order from training feature columns") + + return window_size, feature_order + + def preprocess_candles( candles: List[dict], - pipeline_config: PipelineConfig + pipeline_config: PipelineConfig, + training_feature_columns: Optional[List[str]] = None ) -> Tuple[pd.DataFrame, np.ndarray]: """ Preprocess candle data for inference. @@ -124,16 +167,24 @@ def preprocess_candles( logger.info(f"Filling NaN values in {len(nan_cols)} columns (indicator warmup + missing data)") df = df.fillna(0.0) + # Determine expected feature order and window size + if training_feature_columns: + window_size, feature_order = _parse_training_feature_columns(training_feature_columns) + logger.info(f"Using training feature columns: {len(feature_order)} features, window_size={window_size}") + else: + window_size = TRAINING_WINDOW_SIZE + feature_order = TRAINING_FEATURE_ORDER + # Ensure all expected per-candle features exist - for col in TRAINING_FEATURE_ORDER: + for col in feature_order: if col not in df.columns: logger.warning(f"Missing expected feature column '{col}', filling with 0") df[col] = 0.0 - logger.info(f"Preprocessing complete: {len(df)} candles with {len(TRAINING_FEATURE_ORDER)} features each") + logger.info(f"Preprocessing complete: {len(df)} candles with {len(feature_order)} features each") # Create sliding windows and flatten - X, window_times = create_sliding_windows(df, TRAINING_WINDOW_SIZE, TRAINING_FEATURE_ORDER) + X, window_times = create_sliding_windows(df, window_size, feature_order) return X, window_times diff --git a/services/ml/training/train.py b/services/ml/training/train.py index aaa4d90..4f3ba51 100644 --- a/services/ml/training/train.py +++ b/services/ml/training/train.py @@ -388,7 +388,28 @@ def train( import joblib output_model_path = Path(output_model_path) output_model_path.parent.mkdir(parents=True, exist_ok=True) - joblib.dump(model, output_model_path) + + labels = [] + try: + if hasattr(model, "classes_"): + labels = [str(c) for c in model.classes_] + elif hasattr(model, "model") and hasattr(model.model, "classes_"): + labels = [str(c) for c in model.model.classes_] + except Exception: + labels = [] + + model_data = { + "model": model, + "metadata": { + "model_type": training_config.model_type, + "trained_at": datetime.utcnow().isoformat(), + "run_id": run_id, + "feature_columns": feature_cols, + "feature_engineering_enabled": config.stages.feature_engineering.enabled, + "labels": labels, + }, + } + joblib.dump(model_data, output_model_path) logger.info(f"Saved model to {output_model_path}") # Update training run record in PostgreSQL