From 3a83fd38e9db0e173e81ed2d04268ee984e2de85 Mon Sep 17 00:00:00 2001 From: Marko Djordjevic Date: Sun, 15 Feb 2026 14:29:07 +0100 Subject: [PATCH] feat(ml): implement FastAPI inference service with model loading, preprocessing, and prediction endpoints --- openspec/changes/candle-backend/tasks.md | 20 +- services/ml/app/main.py | 750 +++++++++++++++++++++++ services/ml/app/preprocessing.py | 185 ++++++ 3 files changed, 945 insertions(+), 10 deletions(-) create mode 100644 services/ml/app/main.py create mode 100644 services/ml/app/preprocessing.py diff --git a/openspec/changes/candle-backend/tasks.md b/openspec/changes/candle-backend/tasks.md index f589f99..fc8fdd8 100644 --- a/openspec/changes/candle-backend/tasks.md +++ b/openspec/changes/candle-backend/tasks.md @@ -51,16 +51,16 @@ ## 6. Inference Service (FastAPI) -- [ ] 6.1 Create `services/ml/app/main.py` — FastAPI app with CORS, startup event to load model -- [ ] 6.2 Implement model loading — from MLflow registry (by name + stage) or from local .pkl file via joblib -- [ ] 6.3 Implement preprocessing parity — load pipeline config from MLflow artifact, apply same feature engineering as training -- [ ] 6.4 Create `POST /predict` endpoint — accept candles array, run preprocessing, predict, return per-candle labels + confidence + spans + model_info -- [ ] 6.5 Implement prediction span grouping — group consecutive same-label non-"O" predictions into spans with avg_confidence -- [ ] 6.6 Create `POST /predict/batch` endpoint — accept pair/timeframe/date range, load data, process in batch_size chunks, return predictions -- [ ] 6.7 Create `GET /model/info` endpoint — return model metadata, per-class metrics from MLflow -- [ ] 6.8 Create `GET /model/labels` endpoint — return label names and colors -- [ ] 6.9 Create `GET /health` endpoint — check model loaded status, MLflow connection, PostgreSQL connection -- [ ] 6.10 Add Pydantic request/response models for all endpoints (PredictRequest, PredictResponse, 
BatchPredictRequest, ModelInfoResponse) +- [x] 6.1 Create `services/ml/app/main.py` — FastAPI app with CORS, startup event to load model +- [x] 6.2 Implement model loading — from MLflow registry (by name + stage) or from local .pkl file via joblib +- [x] 6.3 Implement preprocessing parity — load pipeline config from MLflow artifact, apply same feature engineering as training +- [x] 6.4 Create `POST /predict` endpoint — accept candles array, run preprocessing, predict, return per-candle labels + confidence + spans + model_info +- [x] 6.5 Implement prediction span grouping — group consecutive same-label non-"O" predictions into spans with avg_confidence +- [x] 6.6 Create `POST /predict/batch` endpoint — accept pair/timeframe/date range, load data, process in batch_size chunks, return predictions +- [x] 6.7 Create `GET /model/info` endpoint — return model metadata, per-class metrics from MLflow +- [x] 6.8 Create `GET /model/labels` endpoint — return label names and colors +- [x] 6.9 Create `GET /health` endpoint — check model loaded status, MLflow connection, PostgreSQL connection +- [x] 6.10 Add Pydantic request/response models for all endpoints (PredictRequest, PredictResponse, BatchPredictRequest, ModelInfoResponse) ## 7. Next.js API Proxy Routes diff --git a/services/ml/app/main.py b/services/ml/app/main.py new file mode 100644 index 0000000..7524d77 --- /dev/null +++ b/services/ml/app/main.py @@ -0,0 +1,750 @@ +""" +FastAPI inference service for candlestick pattern prediction. + +Provides REST API endpoints for model serving, health checks, and prediction. 
# Global state
class AppState:
    """Mutable container for the objects the service loads at startup.

    Every attribute defaults to ``None`` at class level; the startup hook
    assigns real values on the module-level ``state`` instance.
    """

    # Loaded estimator (MLflow pyfunc wrapper or joblib-restored object).
    model: Optional[Any] = None
    # Metadata dict produced by the model loaders (name, version, labels, ...).
    model_info: Optional[Dict[str, Any]] = None
    # Parsed pipeline configuration; also drives inference-time preprocessing.
    pipeline_config: Optional[PipelineConfig] = None
    # Feature column names (not assigned anywhere in the code visible here).
    feature_columns: Optional[List[str]] = None
    # Optional mapping from integer class ids to label names.
    label_encoder: Optional[Dict[int, str]] = None


state = AppState()


# --- Pydantic Models ---

class CandleData(BaseModel):
    """Single OHLCV candle."""
    time: int = Field(..., description="Unix timestamp in seconds")
    open: float
    high: float
    low: float
    close: float
    volume: float


class PredictRequest(BaseModel):
    """Request body for the /predict endpoint."""
    pair: str = Field(..., description="Trading pair (e.g., EURUSD)")
    timeframe: str = Field(..., description="Timeframe (e.g., 1H, 4H, 1D)")
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")


class PredictionResult(BaseModel):
    """Prediction for a single candle."""
    time: int
    label: str
    confidence: float = Field(..., ge=0.0, le=1.0)


class PredictionSpan(BaseModel):
    """Run of consecutive candles that share the same predicted label."""
    start_time: int
    end_time: int
    label: str
    avg_confidence: float = Field(..., ge=0.0, le=1.0)
    candle_count: int


class ModelInfo(BaseModel):
    """Model metadata returned alongside predictions."""
    # Pydantic v2 reserves the "model_" prefix for its own API and warns about
    # fields such as ``model_name``; the field names are part of the response
    # contract, so opt out of the protected namespace instead of renaming.
    model_config = {"protected_namespaces": ()}

    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool


class PredictResponse(BaseModel):
    """Response body for the /predict and /predict/batch endpoints."""
    # ``model_info`` would otherwise trip the protected "model_" namespace.
    model_config = {"protected_namespaces": ()}

    predictions: List[PredictionResult]
    spans: List[PredictionSpan]
    model_info: ModelInfo


class BatchPredictRequest(BaseModel):
    """Request body for the /predict/batch endpoint."""
    pair: str
    timeframe: str
    start_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")


class LabelInfo(BaseModel):
    """Pattern label with its display color."""
    name: str
    color: str


class PerClassMetrics(BaseModel):
    """Performance metrics for a single class."""
    label: str
    precision: float
    recall: float
    f1_score: float
    support: int


class ModelInfoResponse(BaseModel):
    """Extended model information including per-class metrics."""
    # See ModelInfo: keep the ``model_*`` field names without warnings.
    model_config = {"protected_namespaces": ()}

    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
    labels: List[str]
    per_class_metrics: List[PerClassMetrics]


class HealthResponse(BaseModel):
    """Response body for the /health endpoint."""
    # ``model_loaded`` also starts with the protected "model_" prefix.
    model_config = {"protected_namespaces": ()}

    status: str = Field(..., description="healthy, degraded, or unhealthy")
    model_loaded: bool
    mlflow: str = Field(..., description="connected or disconnected")
    database: str = Field(..., description="connected or disconnected")
def _collect_per_class_metrics(
    metrics: Dict[str, float],
    params: Dict[str, str],
) -> List[Dict[str, Any]]:
    """Turn ``precision_<label>`` run metrics into per-class metric rows."""
    prefix = "precision_"
    rows: List[Dict[str, Any]] = []

    for key, precision in metrics.items():
        if not key.startswith(prefix):
            continue

        # Slice the prefix off rather than str.replace(), which would also
        # remove "precision_" from the middle of a label name.
        label = key[len(prefix):]

        try:
            # Run params are strings; tolerate values such as "12.0".
            support = int(float(params.get(f"support_{label}", 0)))
        except (TypeError, ValueError):
            # A malformed display value must not abort model loading.
            support = 0

        rows.append({
            "label": label,
            "precision": precision,
            "recall": metrics.get(f"recall_{label}", 0.0),
            "f1_score": metrics.get(f"f1_{label}", 0.0),
            "support": support,
        })

    return rows


def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from MLflow model registry.

    The registry is queried first for the latest version in ``stage`` and the
    model is then loaded by that explicit version, so the returned metadata
    always describes the artifact that was actually loaded.

    Args:
        model_name: Name of the registered model
        stage: Model stage (Production, Staging, None)

    Returns:
        Tuple of (model, model_info)

    Raises:
        ValueError: If no version of the model exists at the given stage
        Exception: If MLflow is unavailable or loading fails
    """
    # Function-scope import: the module only imports ``datetime`` itself.
    from datetime import timezone

    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")

    try:
        client = mlflow.tracking.MlflowClient()
        model_versions = client.get_latest_versions(model_name, stages=[stage])

        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")

        model_version = model_versions[0]

        # Load by version number instead of by stage so a promotion happening
        # between the two registry calls cannot desynchronise model and metadata.
        # NOTE(review): the pyfunc wrapper may not expose ``predict_proba``;
        # callers probing for it would then report a confidence of 1.0 - confirm.
        model_uri = f"models:/{model_name}/{model_version.version}"
        model = mlflow.pyfunc.load_model(model_uri)

        # Get run metadata
        run = client.get_run(model_version.run_id)
        params = run.data.params

        # start_time is in milliseconds; emit an unambiguous UTC timestamp.
        start_ms = run.info.start_time
        trained_at = (
            datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc).isoformat()
            if start_ms is not None
            else None
        )

        try:
            labels = json.loads(params.get("labels", "[]"))
        except json.JSONDecodeError:
            # A truncated or malformed param should not prevent serving.
            logger.warning("Run param 'labels' is not valid JSON; using empty label list")
            labels = []

        model_info = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            "trained_at": trained_at,
            "dataset_version": params.get("dataset_version", None),
            # Compare case-insensitively: a logged Python bool is stored as "True".
            "feature_engineering_enabled": (
                str(params.get("feature_engineering", "true")).strip().lower() == "true"
            ),
            "labels": labels,
            "per_class_metrics": _collect_per_class_metrics(run.data.metrics, params),
        }

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from local file using joblib.

    The file may contain either a bare estimator or a dict bundle with a
    ``"model"`` entry and an optional ``"metadata"`` dict.

    Args:
        model_path: Path to .pkl model file

    Returns:
        Tuple of (model, model_info)

    Raises:
        FileNotFoundError: If model file doesn't exist
        ValueError: If a dict bundle does not contain a model
        Exception: If model loading fails
    """
    path = Path(model_path)
    logger.info(f"Loading model from local file: {path}")

    if not path.exists():
        raise FileNotFoundError(f"Model file not found: {path}")

    try:
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only point this at model files from a trusted source.
        model_data = joblib.load(path)

        if isinstance(model_data, dict):
            model = model_data.get("model")
            # ``or {}`` also covers an explicit ``"metadata": None`` entry.
            metadata = model_data.get("metadata") or {}
            if model is None:
                # Previously a bundle without a model was returned as a
                # "loaded" model of None, leaving the service silently unusable.
                raise ValueError(f"Model bundle {path} does not contain a 'model' entry")
        else:
            model = model_data
            metadata = {}

        model_info = {
            "model_name": path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": metadata.get("labels", []),
            "per_class_metrics": metadata.get("per_class_metrics", []),
        }

        logger.info(f"Successfully loaded local model: {path.name}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise
@app.on_event("startup")
async def startup_event():
    """
    Load model and pipeline config on startup.

    Failures never abort startup: the service comes up without a model and
    the model-dependent endpoints answer with an error status until a
    restart succeeds.
    """
    logger.info("Starting inference service...")

    # Load pipeline config
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        # No defaults exist: without a config neither preprocessing nor model
        # loading can happen, so say so instead of claiming defaults are used.
        logger.warning(
            f"Config file not found: {config_path}. "
            "Service will start without a pipeline config or model."
        )
        return

    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")
    except Exception as e:
        logger.error(f"Failed to load pipeline config: {e}", exc_info=True)
        return

    inference_config = state.pipeline_config.stages.inference

    if not inference_config.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    try:
        if inference_config.model_source == "mlflow":
            # The tracking URI comes from the training stage's MLflow settings.
            mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)

            model, model_info = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage
            )
        elif inference_config.model_source == "local":
            model, model_info = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        else:
            logger.error(f"Unknown model_source: {inference_config.model_source}")
            return

        if model is None:
            # Guard against a loader handing back no estimator; otherwise the
            # log would claim success while every endpoint reports "no model".
            raise ValueError("Model loader returned no model")

        # Publish both values only after loading fully succeeded.
        state.model, state.model_info = model, model_info

        logger.info("Model loaded successfully")
        logger.info(f"Model info: {state.model_info['model_name']} "
                    f"v{state.model_info['model_version']} "
                    f"({state.model_info['model_type']})")

    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        logger.warning("Service will start without a model. Use /health to check status.")
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Report service health.

    Combines the model-loaded flag with the MLflow and database status
    strings into an overall status of healthy, degraded, or unhealthy.
    """
    # NOTE(review): neither dependency is probed yet; both are reported as
    # "connected" unconditionally until the TODOs below are implemented.
    # TODO: Actually check MLflow connection
    mlflow_state = "connected"
    # TODO: Actually check database connection
    database_state = "connected"

    has_model = state.model is not None
    dependencies_up = mlflow_state == "connected" and database_state == "connected"

    if not has_model:
        summary = "unhealthy"
    elif dependencies_up:
        summary = "healthy"
    else:
        summary = "degraded"

    return HealthResponse(
        status=summary,
        model_loaded=has_model,
        mlflow=mlflow_state,
        database=database_state
    )


# --- Model Info Endpoints (Stubs) ---

@app.get("/model/info", response_model=ModelInfoResponse)
async def get_model_info():
    """
    Return metadata and per-class metrics of the loaded model.

    Responds with HTTP 503 when no model (or no model metadata) is loaded.
    """
    info = state.model_info
    if state.model is None or info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Assemble the payload field by field from the stored metadata dict;
    # required keys use indexing, optional ones fall back to None / [].
    payload = {
        "model_name": info["model_name"],
        "model_version": info.get("model_version"),
        "model_type": info["model_type"],
        "trained_at": info.get("trained_at"),
        "dataset_version": info.get("dataset_version"),
        "feature_engineering_enabled": info["feature_engineering_enabled"],
        "labels": info.get("labels", []),
        "per_class_metrics": [
            PerClassMetrics(**row) for row in info.get("per_class_metrics", [])
        ],
    }
    return ModelInfoResponse(**payload)
@app.get("/model/labels", response_model=List[LabelInfo])
async def get_model_labels():
    """
    List the pattern labels known for the loaded model, each with a color.

    Responds with HTTP 503 when no model is loaded.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # TODO: Load label colors from config or database
    # Until then every label gets the same placeholder color.
    known_labels = state.model_info.get("labels", []) if state.model_info else []

    response = []
    for name in known_labels:
        response.append(LabelInfo(name=name, color="#888888"))
    return response


# --- Prediction Helper Functions ---

def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Collapse runs of identical, consecutive labels into spans.

    Predictions labelled "O" are background: they never become part of a
    span and they terminate whatever span is currently open.

    Args:
        predictions: List of per-candle predictions

    Returns:
        List of prediction spans, in input order
    """
    grouped = []
    if not predictions:
        return grouped

    open_span = None

    for item in predictions:
        is_background = item.label == "O"

        extends_open_span = (
            open_span is not None
            and not is_background
            and open_span.label == item.label
        )
        if extends_open_span:
            open_span.end_time = item.time
            open_span.candle_count += 1
            # Running mean: rebuild the previous total, add the new value.
            previous_total = open_span.avg_confidence * (open_span.candle_count - 1)
            open_span.avg_confidence = (previous_total + item.confidence) / open_span.candle_count
            continue

        # Either background or a label change: close the open span, if any.
        if open_span is not None:
            grouped.append(open_span)
            open_span = None

        if not is_background:
            open_span = PredictionSpan(
                start_time=item.time,
                end_time=item.time,
                label=item.label,
                avg_confidence=item.confidence,
                candle_count=1
            )

    # Flush a span that runs to the end of the input.
    if open_span is not None:
        grouped.append(open_span)

    logger.info(f"Grouped {len(predictions)} predictions into {len(grouped)} spans")

    return grouped
# --- Prediction Endpoints ---

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict candlestick patterns for provided candles.

    Accepts OHLCV candles, runs preprocessing, and returns per-candle
    predictions, grouped spans, and metadata of the model that produced them.

    Raises:
        HTTPException: 503 when no model is loaded, 400 for empty or invalid
            candle data, 500 when the pipeline config is missing or the
            prediction itself fails.
    """
    # Imported here because the module's import block never brings NumPy into
    # scope; without it every call failed with NameError on ``np``.
    import numpy as np

    if state.model is None or state.model_info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    if not request.candles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Candles array cannot be empty"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(f"Predict request: {request.pair} {request.timeframe}, {len(request.candles)} candles")

    try:
        # Convert candles to list of dicts
        candles_data = [candle.model_dump() for candle in request.candles]

        # Preprocess candles (feature engineering)
        df_preprocessed = preprocess_candles(candles_data, state.pipeline_config)

        # Preprocessing may drop rows, so take the timestamps from its output
        # to keep them aligned with the predictions.
        times = df_preprocessed['time'].values

        # Extract feature columns (exclude 'time')
        X = extract_feature_columns(df_preprocessed)

        y_pred = state.model.predict(X)
        if hasattr(state.model, 'predict_proba'):
            # Confidence is the probability of the most likely class.
            confidences = np.max(state.model.predict_proba(X), axis=1)
        else:
            # Models without probabilities get a default confidence of 1.0.
            confidences = np.ones(len(y_pred))

        if state.label_encoder is not None:
            # Model predicts integers, map to labels
            labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
        else:
            # Model predicts strings directly
            labels = [str(pred) for pred in y_pred]

        predictions = [
            PredictionResult(
                time=int(time),
                label=label,
                confidence=float(conf)
            )
            for time, label, conf in zip(times, labels, confidences)
        ]

        spans = group_prediction_spans(predictions)

        model_info = ModelInfo(
            model_name=state.model_info["model_name"],
            model_version=state.model_info.get("model_version"),
            model_type=state.model_info["model_type"],
            trained_at=state.model_info.get("trained_at"),
            dataset_version=state.model_info.get("dataset_version"),
            feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
        )

        pattern_count = sum(1 for p in predictions if p.label != 'O')
        logger.info(
            f"Prediction complete: {len(predictions)} candles, "
            f"{len(spans)} spans, {pattern_count} patterns"
        )

        return PredictResponse(
            predictions=predictions,
            spans=spans,
            model_info=model_info
        )

    except ValueError as e:
        # Validation problems in the submitted candles are the client's error.
        logger.error(f"Prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        ) from e
@app.post("/predict/batch", response_model=PredictResponse)
async def predict_batch(request: BatchPredictRequest):
    """
    Batch prediction for a date range.

    Loads candles from the configured raw data file, keeps those inside the
    requested range, preprocesses them once, and runs the model over the
    result in ``batch_size`` chunks.

    A date-only ``end_date`` is treated as inclusive of that whole day.

    NOTE(review): ``pair`` and ``timeframe`` are only logged and echoed in
    error messages; the data file is selected by config alone - confirm that
    this is intended.

    Raises:
        HTTPException: 503 when no model is loaded, 400 for an invalid date
            range or data that fails preprocessing, 404 when no data exists,
            500 for a missing pipeline config or any other failure.
    """
    # Imported here because the module's import block never brings NumPy into
    # scope; without it every call failed with NameError on ``np``.
    import numpy as np

    if state.model is None or state.model_info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(
        f"Batch predict: {request.pair} {request.timeframe} "
        f"from {request.start_date} to {request.end_date}"
    )

    # Validate the range before touching any data so malformed dates are
    # reported as a client error instead of an internal one.
    try:
        start = pd.Timestamp(request.start_date)
        end = pd.Timestamp(request.end_date)
        if pd.isna(start) or pd.isna(end):
            raise ValueError("start_date and end_date must be valid dates")
        if end == end.normalize():
            # Date without a time component: include the entire end day
            # rather than only its first second.
            end = end + pd.Timedelta(days=1) - pd.Timedelta(seconds=1)
        start_ts = int(start.timestamp())
        end_ts = int(end.timestamp())
    except (ValueError, TypeError, OverflowError) as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid date range: {e}"
        ) from e

    if start_ts > end_ts:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="start_date must not be after end_date"
        )

    try:
        # Load OHLCV data from raw data path
        raw_path = Path(state.pipeline_config.data.raw_path)

        if not raw_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No data found for {request.pair} {request.timeframe}"
            )

        df_raw = pd.read_csv(raw_path)

        if 'time' in df_raw.columns:
            df_filtered = df_raw[
                (df_raw['time'] >= start_ts) & (df_raw['time'] <= end_ts)
            ].copy()

            if len(df_filtered) == 0:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"No data found in date range {request.start_date} to {request.end_date}"
                )

            logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
        else:
            df_filtered = df_raw
            logger.warning("No 'time' column found, using all data")

        # Preprocess the whole range once. Doing it per chunk discarded the
        # indicator warm-up rows at the start of every chunk and could reject
        # a short final chunk for dropping more than half of its rows.
        try:
            df_preprocessed = preprocess_candles(
                df_filtered.to_dict('records'), state.pipeline_config
            )
        except ValueError as e:
            logger.error(f"Batch preprocessing validation error: {e}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e)
            ) from e

        times = df_preprocessed['time'].values
        X = extract_feature_columns(df_preprocessed)

        # A missing or non-positive batch size means "everything in one chunk".
        batch_size = state.pipeline_config.stages.inference.batch_size
        if not batch_size or batch_size < 1:
            batch_size = max(len(X), 1)

        has_proba = hasattr(state.model, 'predict_proba')
        all_predictions = []

        for i in range(0, len(X), batch_size):
            X_batch = X.iloc[i:i + batch_size]
            batch_times = times[i:i + batch_size]

            logger.info(f"Processing batch {i // batch_size + 1}: rows {i} to {i + len(X_batch)}")

            y_pred = state.model.predict(X_batch)
            if has_proba:
                confidences = np.max(state.model.predict_proba(X_batch), axis=1)
            else:
                confidences = np.ones(len(y_pred))

            if state.label_encoder is not None:
                labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
            else:
                labels = [str(pred) for pred in y_pred]

            all_predictions.extend(
                PredictionResult(
                    time=int(time),
                    label=label,
                    confidence=float(conf)
                )
                for time, label, conf in zip(batch_times, labels, confidences)
            )

        # Group across chunk boundaries so a pattern is not split in two.
        all_spans = group_prediction_spans(all_predictions)

        model_info = ModelInfo(
            model_name=state.model_info["model_name"],
            model_version=state.model_info.get("model_version"),
            model_type=state.model_info["model_type"],
            trained_at=state.model_info.get("trained_at"),
            dataset_version=state.model_info.get("dataset_version"),
            feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
        )

        logger.info(
            f"Batch prediction complete: {len(all_predictions)} candles, "
            f"{len(all_spans)} spans"
        )

        return PredictResponse(
            predictions=all_predictions,
            spans=all_spans,
            model_info=model_info
        )

    except HTTPException:
        # Deliberate HTTP errors raised above must pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch prediction failed: {str(e)}"
        ) from e
def preprocess_candles(
    candles: List[dict],
    pipeline_config: "PipelineConfig"
) -> pd.DataFrame:
    """
    Preprocess candle data for inference.

    Applies the same feature engineering steps as used during training:
    1. Convert to DataFrame
    2. Validate OHLC data
    3. Compute TA-Lib indicators (if enabled)
    4. Compute candle features (if enabled)
    5. Load custom features (if configured)
    6. Drop NaN rows from indicator warmup

    When feature engineering is disabled in the config, the validated raw
    frame is returned unchanged.

    Args:
        candles: List of candle dictionaries with time, open, high, low, close, volume
        pipeline_config: Pipeline configuration (must match training config)

    Returns:
        Preprocessed DataFrame ready for model.predict()

    Raises:
        ValueError: If no candles are given, data validation or a feature
            step fails, or more than half of the rows are dropped
    """
    if not candles:
        # An empty frame has no columns at all, which used to surface as a
        # misleading "must include 'time' field" error.
        raise ValueError("Candles list is empty; provide at least one candle")

    df = pd.DataFrame(candles)

    # The time column is needed to map predictions back to candles.
    if 'time' not in df.columns:
        raise ValueError("Candles must include 'time' field")

    original_rows = len(df)
    logger.info(f"Preprocessing {original_rows} candles")

    # Validate OHLC data
    try:
        validate_candle_data(df)
    except Exception as e:
        # Chain the cause so the original traceback stays available.
        raise ValueError(f"Candle data validation failed: {e}") from e

    fe_config = pipeline_config.stages.feature_engineering

    if not fe_config.enabled:
        logger.warning("Feature engineering disabled in config - returning raw OHLCV")
        return df

    # Compute TA-Lib indicators
    if fe_config.talib_indicators:
        logger.info(f"Computing {len(fe_config.talib_indicators)} TA-Lib indicators")
        try:
            df = compute_talib_indicators(df, fe_config.talib_indicators)
        except Exception as e:
            logger.error(f"Failed to compute TA-Lib indicators: {e}")
            raise ValueError(f"Indicator computation failed: {e}") from e

    # Compute candle features
    if fe_config.candle_features:
        logger.info("Computing candle features")
        try:
            df = compute_candle_features(df)
        except Exception as e:
            logger.error(f"Failed to compute candle features: {e}")
            raise ValueError(f"Candle feature computation failed: {e}") from e

    # Load custom features
    if fe_config.custom_features:
        logger.info(f"Loading {len(fe_config.custom_features)} custom feature(s)")
        try:
            df = load_custom_features(df, fe_config.custom_features)
        except Exception as e:
            logger.error(f"Failed to load custom features: {e}")
            raise ValueError(f"Custom feature loading failed: {e}") from e

    # Rows containing any NaN (e.g. from indicator warmup) cannot be scored.
    df_clean = df.dropna()

    rows_dropped = original_rows - len(df_clean)

    if rows_dropped > 0:
        logger.info(
            f"Dropped {rows_dropped} rows due to indicator warmup "
            f"({rows_dropped / original_rows * 100:.1f}%)"
        )

        # Reject the request when most of the input was unusable.
        if rows_dropped / original_rows > 0.5:
            raise ValueError(
                f"More than 50% of candles dropped due to indicator warmup "
                f"({rows_dropped}/{original_rows}). Provide more historical candles."
            )

    logger.info(f"Preprocessing complete: {len(df_clean)} candles ready for prediction")

    return df_clean
def extract_feature_columns(
    df: pd.DataFrame,
    exclude_columns: List[str] = None
) -> pd.DataFrame:
    """
    Extract only feature columns for model prediction.

    Removes metadata columns like 'time' that should not be used as features.
    The remaining columns keep their original order.

    Args:
        df: Preprocessed DataFrame
        exclude_columns: Columns to exclude (default: ['time'])

    Returns:
        DataFrame with only feature columns
    """
    if exclude_columns is None:
        exclude_columns = ['time']

    excluded = set(exclude_columns)
    feature_cols = [col for col in df.columns if col not in excluded]

    logger.info(f"Using {len(feature_cols)} feature columns for prediction")

    return df[feature_cols]


def validate_feature_parity(
    inference_features: List[str],
    training_features: List[str],
    check_order: bool = False
) -> bool:
    """
    Validate that inference features match training features.

    By default only the set of names is compared. Pass ``check_order=True``
    to additionally require the same column order, for models that consume
    features by position.

    Args:
        inference_features: Feature column names from inference preprocessing
        training_features: Feature column names from training
        check_order: Also require identical ordering (default: False)

    Returns:
        True if features match

    Raises:
        ValueError: If features don't match
    """
    inference_set = set(inference_features)
    training_set = set(training_features)

    missing = training_set - inference_set
    extra = inference_set - training_set

    if missing or extra:
        parts = ["Feature mismatch detected between training and inference:\n"]

        if missing:
            parts.append(f" Missing features: {sorted(missing)}\n")

        if extra:
            parts.append(f" Extra features: {sorted(extra)}\n")

        parts.append("\nThis indicates preprocessing parity is broken. ")
        parts.append("Ensure the pipeline config used for inference matches training.")
        error_msg = "".join(parts)

        logger.error(error_msg)
        raise ValueError(error_msg)

    if check_order and list(inference_features) != list(training_features):
        error_msg = (
            "Feature order mismatch between training and inference: "
            f"expected {list(training_features)}, got {list(inference_features)}"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)

    logger.info("Feature parity validated: inference features match training")
    return True