Fix inference feature mismatch with training metadata
This commit is contained in:
parent
328476a581
commit
73c10a4156
3 changed files with 137 additions and 10 deletions
|
|
@ -124,6 +124,7 @@ async def lifespan(app: FastAPI):
|
|||
else:
|
||||
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
||||
|
||||
state.feature_columns = state.model_info.get("feature_columns")
|
||||
logger.info("Model loaded successfully")
|
||||
logger.info(f"Model info: {state.model_info['model_name']} "
|
||||
f"v{state.model_info['model_version']} "
|
||||
|
|
@ -181,6 +182,22 @@ class AppState:
|
|||
|
||||
state = AppState()
|
||||
|
||||
def _get_model_n_features(model: Any) -> Optional[int]:
|
||||
"""
|
||||
Extract expected feature count from a model (supports wrappers).
|
||||
"""
|
||||
if hasattr(model, "n_features_in_"):
|
||||
try:
|
||||
return int(model.n_features_in_)
|
||||
except Exception:
|
||||
return None
|
||||
if hasattr(model, "model") and hasattr(model.model, "n_features_in_"):
|
||||
try:
|
||||
return int(model.model.n_features_in_)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
# --- Pydantic Models ---
|
||||
|
||||
|
|
@ -467,7 +484,8 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
|||
"dataset_version": metadata.get("dataset_version", None),
|
||||
"feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
|
||||
"labels": labels,
|
||||
"per_class_metrics": metadata.get("per_class_metrics", [])
|
||||
"per_class_metrics": metadata.get("per_class_metrics", []),
|
||||
"feature_columns": metadata.get("feature_columns", None),
|
||||
}
|
||||
|
||||
logger.info(f"Successfully loaded local model: {model_path.name}")
|
||||
|
|
@ -707,7 +725,25 @@ async def predict(request: PredictRequest):
|
|||
candles_data = [candle.model_dump() for candle in candles_sorted]
|
||||
|
||||
# Preprocess candles (feature engineering + windowing)
|
||||
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
||||
training_feature_columns = (
|
||||
current_model_info.get("feature_columns") if current_model_info else None
|
||||
)
|
||||
X, window_times = preprocess_candles(
|
||||
candles_data,
|
||||
state.pipeline_config,
|
||||
training_feature_columns=training_feature_columns
|
||||
)
|
||||
|
||||
expected_n_features = _get_model_n_features(current_model)
|
||||
if expected_n_features is not None and X.shape[1] != expected_n_features:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
f"Feature mismatch: model expects {expected_n_features} features, "
|
||||
f"but preprocessing produced {X.shape[1]}. "
|
||||
"Ensure the loaded model matches the inference preprocessing config."
|
||||
),
|
||||
)
|
||||
|
||||
# Get predictions and probabilities (using local reference, outside lock)
|
||||
if hasattr(current_model, 'predict_proba'):
|
||||
|
|
@ -765,7 +801,7 @@ async def predict(request: PredictRequest):
|
|||
logger.error(f"Prediction validation error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Internal server error"
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Prediction failed: {e}", exc_info=True)
|
||||
|
|
@ -883,7 +919,25 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
batch_candles = batch_df.to_dict('records')
|
||||
|
||||
# Preprocess (feature engineering + windowing)
|
||||
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
||||
training_feature_columns = (
|
||||
current_model_info.get("feature_columns") if current_model_info else None
|
||||
)
|
||||
X, window_times = preprocess_candles(
|
||||
batch_candles,
|
||||
state.pipeline_config,
|
||||
training_feature_columns=training_feature_columns
|
||||
)
|
||||
|
||||
expected_n_features = _get_model_n_features(current_model)
|
||||
if expected_n_features is not None and X.shape[1] != expected_n_features:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
f"Feature mismatch: model expects {expected_n_features} features, "
|
||||
f"but preprocessing produced {X.shape[1]}. "
|
||||
"Ensure the loaded model matches the inference preprocessing config."
|
||||
),
|
||||
)
|
||||
|
||||
# Predict (using local reference, outside lock)
|
||||
if hasattr(current_model, 'predict_proba'):
|
||||
|
|
@ -1755,6 +1809,7 @@ async def model_load(request: ModelLoadRequest):
|
|||
with _model_swap_lock:
|
||||
state.model = new_model
|
||||
state.model_info = new_model_info
|
||||
state.feature_columns = new_model_info.get("feature_columns")
|
||||
|
||||
logger.info(f"Model hot-swapped: run_id={request.run_id}, type={row.model_type}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue