Fix inference feature mismatch with training metadata
This commit is contained in:
parent
328476a581
commit
73c10a4156
3 changed files with 137 additions and 10 deletions
|
|
@ -124,6 +124,7 @@ async def lifespan(app: FastAPI):
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
||||||
|
|
||||||
|
state.feature_columns = state.model_info.get("feature_columns")
|
||||||
logger.info("Model loaded successfully")
|
logger.info("Model loaded successfully")
|
||||||
logger.info(f"Model info: {state.model_info['model_name']} "
|
logger.info(f"Model info: {state.model_info['model_name']} "
|
||||||
f"v{state.model_info['model_version']} "
|
f"v{state.model_info['model_version']} "
|
||||||
|
|
@ -181,6 +182,22 @@ class AppState:
|
||||||
|
|
||||||
state = AppState()
|
state = AppState()
|
||||||
|
|
||||||
|
def _get_model_n_features(model: Any) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Extract expected feature count from a model (supports wrappers).
|
||||||
|
"""
|
||||||
|
if hasattr(model, "n_features_in_"):
|
||||||
|
try:
|
||||||
|
return int(model.n_features_in_)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if hasattr(model, "model") and hasattr(model.model, "n_features_in_"):
|
||||||
|
try:
|
||||||
|
return int(model.model.n_features_in_)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# --- Pydantic Models ---
|
# --- Pydantic Models ---
|
||||||
|
|
||||||
|
|
@ -467,7 +484,8 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
||||||
"dataset_version": metadata.get("dataset_version", None),
|
"dataset_version": metadata.get("dataset_version", None),
|
||||||
"feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
|
"feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
"per_class_metrics": metadata.get("per_class_metrics", [])
|
"per_class_metrics": metadata.get("per_class_metrics", []),
|
||||||
|
"feature_columns": metadata.get("feature_columns", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Successfully loaded local model: {model_path.name}")
|
logger.info(f"Successfully loaded local model: {model_path.name}")
|
||||||
|
|
@ -707,7 +725,25 @@ async def predict(request: PredictRequest):
|
||||||
candles_data = [candle.model_dump() for candle in candles_sorted]
|
candles_data = [candle.model_dump() for candle in candles_sorted]
|
||||||
|
|
||||||
# Preprocess candles (feature engineering + windowing)
|
# Preprocess candles (feature engineering + windowing)
|
||||||
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
training_feature_columns = (
|
||||||
|
current_model_info.get("feature_columns") if current_model_info else None
|
||||||
|
)
|
||||||
|
X, window_times = preprocess_candles(
|
||||||
|
candles_data,
|
||||||
|
state.pipeline_config,
|
||||||
|
training_feature_columns=training_feature_columns
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_n_features = _get_model_n_features(current_model)
|
||||||
|
if expected_n_features is not None and X.shape[1] != expected_n_features:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=(
|
||||||
|
f"Feature mismatch: model expects {expected_n_features} features, "
|
||||||
|
f"but preprocessing produced {X.shape[1]}. "
|
||||||
|
"Ensure the loaded model matches the inference preprocessing config."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Get predictions and probabilities (using local reference, outside lock)
|
# Get predictions and probabilities (using local reference, outside lock)
|
||||||
if hasattr(current_model, 'predict_proba'):
|
if hasattr(current_model, 'predict_proba'):
|
||||||
|
|
@ -765,7 +801,7 @@ async def predict(request: PredictRequest):
|
||||||
logger.error(f"Prediction validation error: {e}")
|
logger.error(f"Prediction validation error: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Internal server error"
|
detail=str(e)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Prediction failed: {e}", exc_info=True)
|
logger.error(f"Prediction failed: {e}", exc_info=True)
|
||||||
|
|
@ -883,7 +919,25 @@ async def predict_batch(request: BatchPredictRequest):
|
||||||
batch_candles = batch_df.to_dict('records')
|
batch_candles = batch_df.to_dict('records')
|
||||||
|
|
||||||
# Preprocess (feature engineering + windowing)
|
# Preprocess (feature engineering + windowing)
|
||||||
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
training_feature_columns = (
|
||||||
|
current_model_info.get("feature_columns") if current_model_info else None
|
||||||
|
)
|
||||||
|
X, window_times = preprocess_candles(
|
||||||
|
batch_candles,
|
||||||
|
state.pipeline_config,
|
||||||
|
training_feature_columns=training_feature_columns
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_n_features = _get_model_n_features(current_model)
|
||||||
|
if expected_n_features is not None and X.shape[1] != expected_n_features:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=(
|
||||||
|
f"Feature mismatch: model expects {expected_n_features} features, "
|
||||||
|
f"but preprocessing produced {X.shape[1]}. "
|
||||||
|
"Ensure the loaded model matches the inference preprocessing config."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Predict (using local reference, outside lock)
|
# Predict (using local reference, outside lock)
|
||||||
if hasattr(current_model, 'predict_proba'):
|
if hasattr(current_model, 'predict_proba'):
|
||||||
|
|
@ -1755,6 +1809,7 @@ async def model_load(request: ModelLoadRequest):
|
||||||
with _model_swap_lock:
|
with _model_swap_lock:
|
||||||
state.model = new_model
|
state.model = new_model
|
||||||
state.model_info = new_model_info
|
state.model_info = new_model_info
|
||||||
|
state.feature_columns = new_model_info.get("feature_columns")
|
||||||
|
|
||||||
logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")
|
logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ between training and inference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
import re
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -34,9 +35,51 @@ TRAINING_FEATURE_ORDER = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_training_feature_columns(
|
||||||
|
feature_columns: List[str]
|
||||||
|
) -> Tuple[int, List[str]]:
|
||||||
|
"""
|
||||||
|
Derive window size and per-candle feature order from flattened training columns.
|
||||||
|
|
||||||
|
Expected column format: "<feature>_<index>" (e.g., "open_0", "rsi_14_12").
|
||||||
|
"""
|
||||||
|
if not feature_columns:
|
||||||
|
raise ValueError("Training feature columns are empty")
|
||||||
|
|
||||||
|
feature_order: List[str] = []
|
||||||
|
max_idx = -1
|
||||||
|
idx_set = set()
|
||||||
|
|
||||||
|
for col in feature_columns:
|
||||||
|
match = re.match(r"^(.*)_([0-9]+)$", col)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Invalid training feature column format: {col}")
|
||||||
|
base = match.group(1)
|
||||||
|
idx = int(match.group(2))
|
||||||
|
if idx == 0:
|
||||||
|
feature_order.append(base)
|
||||||
|
if idx > max_idx:
|
||||||
|
max_idx = idx
|
||||||
|
idx_set.add(idx)
|
||||||
|
|
||||||
|
window_size = max_idx + 1
|
||||||
|
if window_size <= 0:
|
||||||
|
raise ValueError("Could not derive window size from training feature columns")
|
||||||
|
|
||||||
|
missing_idx = set(range(window_size)) - idx_set
|
||||||
|
if missing_idx:
|
||||||
|
raise ValueError(f"Missing window indices in training feature columns: {sorted(missing_idx)[:5]}")
|
||||||
|
|
||||||
|
if not feature_order:
|
||||||
|
raise ValueError("Could not derive per-candle feature order from training feature columns")
|
||||||
|
|
||||||
|
return window_size, feature_order
|
||||||
|
|
||||||
|
|
||||||
def preprocess_candles(
|
def preprocess_candles(
|
||||||
candles: List[dict],
|
candles: List[dict],
|
||||||
pipeline_config: PipelineConfig
|
pipeline_config: PipelineConfig,
|
||||||
|
training_feature_columns: Optional[List[str]] = None
|
||||||
) -> Tuple[pd.DataFrame, np.ndarray]:
|
) -> Tuple[pd.DataFrame, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Preprocess candle data for inference.
|
Preprocess candle data for inference.
|
||||||
|
|
@ -124,16 +167,24 @@ def preprocess_candles(
|
||||||
logger.info(f"Filling NaN values in {len(nan_cols)} columns (indicator warmup + missing data)")
|
logger.info(f"Filling NaN values in {len(nan_cols)} columns (indicator warmup + missing data)")
|
||||||
df = df.fillna(0.0)
|
df = df.fillna(0.0)
|
||||||
|
|
||||||
|
# Determine expected feature order and window size
|
||||||
|
if training_feature_columns:
|
||||||
|
window_size, feature_order = _parse_training_feature_columns(training_feature_columns)
|
||||||
|
logger.info(f"Using training feature columns: {len(feature_order)} features, window_size={window_size}")
|
||||||
|
else:
|
||||||
|
window_size = TRAINING_WINDOW_SIZE
|
||||||
|
feature_order = TRAINING_FEATURE_ORDER
|
||||||
|
|
||||||
# Ensure all expected per-candle features exist
|
# Ensure all expected per-candle features exist
|
||||||
for col in TRAINING_FEATURE_ORDER:
|
for col in feature_order:
|
||||||
if col not in df.columns:
|
if col not in df.columns:
|
||||||
logger.warning(f"Missing expected feature column '{col}', filling with 0")
|
logger.warning(f"Missing expected feature column '{col}', filling with 0")
|
||||||
df[col] = 0.0
|
df[col] = 0.0
|
||||||
|
|
||||||
logger.info(f"Preprocessing complete: {len(df)} candles with {len(TRAINING_FEATURE_ORDER)} features each")
|
logger.info(f"Preprocessing complete: {len(df)} candles with {len(feature_order)} features each")
|
||||||
|
|
||||||
# Create sliding windows and flatten
|
# Create sliding windows and flatten
|
||||||
X, window_times = create_sliding_windows(df, TRAINING_WINDOW_SIZE, TRAINING_FEATURE_ORDER)
|
X, window_times = create_sliding_windows(df, window_size, feature_order)
|
||||||
|
|
||||||
return X, window_times
|
return X, window_times
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -388,7 +388,28 @@ def train(
|
||||||
import joblib
|
import joblib
|
||||||
output_model_path = Path(output_model_path)
|
output_model_path = Path(output_model_path)
|
||||||
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
joblib.dump(model, output_model_path)
|
|
||||||
|
labels = []
|
||||||
|
try:
|
||||||
|
if hasattr(model, "classes_"):
|
||||||
|
labels = [str(c) for c in model.classes_]
|
||||||
|
elif hasattr(model, "model") and hasattr(model.model, "classes_"):
|
||||||
|
labels = [str(c) for c in model.model.classes_]
|
||||||
|
except Exception:
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
model_data = {
|
||||||
|
"model": model,
|
||||||
|
"metadata": {
|
||||||
|
"model_type": training_config.model_type,
|
||||||
|
"trained_at": datetime.utcnow().isoformat(),
|
||||||
|
"run_id": run_id,
|
||||||
|
"feature_columns": feature_cols,
|
||||||
|
"feature_engineering_enabled": config.stages.feature_engineering.enabled,
|
||||||
|
"labels": labels,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
joblib.dump(model_data, output_model_path)
|
||||||
logger.info(f"Saved model to {output_model_path}")
|
logger.info(f"Saved model to {output_model_path}")
|
||||||
|
|
||||||
# Update training run record in PostgreSQL
|
# Update training run record in PostgreSQL
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue