"""
FastAPI inference service for candlestick pattern prediction.

Provides REST API endpoints for model serving, health checks, and prediction.
"""

import json
import logging
from datetime import date, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import joblib
import mlflow
import mlflow.pyfunc
import mlflow.sklearn
import mlflow.xgboost
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from app.config import load_config, PipelineConfig
from app.preprocessing import preprocess_candles, extract_feature_columns

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0"
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Global state
class AppState:
    """Application state container (populated once by the startup hook)."""
    model: Optional[Any] = None
    model_info: Optional[Dict[str, Any]] = None
    pipeline_config: Optional[PipelineConfig] = None
    feature_columns: Optional[List[str]] = None
    # Maps integer class ids to label names. NOTE(review): nothing in this file
    # assigns it, so predictions are currently always stringified as-is.
    label_encoder: Optional[Dict[int, str]] = None


state = AppState()


# --- Pydantic Models ---

class CandleData(BaseModel):
    """Single candle data point."""
    time: int = Field(..., description="Unix timestamp in seconds")
    open: float
    high: float
    low: float
    close: float
    volume: float


class PredictRequest(BaseModel):
    """Request model for /predict endpoint."""
    pair: str = Field(..., description="Trading pair (e.g., EURUSD)")
    timeframe: str = Field(..., description="Timeframe (e.g., 1H, 4H, 1D)")
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")


class PredictionResult(BaseModel):
    """Single candle prediction."""
    time: int
    label: str
    confidence: float = Field(..., ge=0.0, le=1.0)


class PredictionSpan(BaseModel):
    """Grouped prediction span."""
    start_time: int
    end_time: int
    label: str
    avg_confidence: float = Field(..., ge=0.0, le=1.0)
    candle_count: int


class ModelInfo(BaseModel):
    """Model metadata."""
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool


class PredictResponse(BaseModel):
    """Response model for /predict endpoint."""
    predictions: List[PredictionResult]
    spans: List[PredictionSpan]
    model_info: ModelInfo


class BatchPredictRequest(BaseModel):
    """Request model for /predict/batch endpoint."""
    pair: str
    timeframe: str
    start_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")
    end_date: str = Field(..., description="ISO format date (YYYY-MM-DD)")


class LabelInfo(BaseModel):
    """Pattern label with display color."""
    name: str
    color: str


class PerClassMetrics(BaseModel):
    """Per-class performance metrics."""
    label: str
    precision: float
    recall: float
    f1_score: float
    support: int


class ModelInfoResponse(BaseModel):
    """Extended model information with metrics."""
    model_name: str
    model_version: Optional[str] = None
    model_type: str
    trained_at: Optional[str] = None
    dataset_version: Optional[str] = None
    feature_engineering_enabled: bool
    labels: List[str]
    per_class_metrics: List[PerClassMetrics]


class HealthResponse(BaseModel):
    """Health check response."""
    status: str = Field(..., description="healthy, degraded, or unhealthy")
    model_loaded: bool
    mlflow: str = Field(..., description="connected or disconnected")
    database: str = Field(..., description="connected or disconnected")


# --- Model Loading Functions ---

def _load_mlflow_model(model_uri: str) -> Any:
    """
    Load ``model_uri``, preferring a native flavor that exposes ``predict_proba``.

    The generic pyfunc wrapper only offers ``predict()``; without ``predict_proba``
    the prediction endpoints would report a confidence of 1.0 for every candle.
    The native sklearn/xgboost model is used only when it actually provides
    ``predict_proba``; otherwise the pyfunc model is returned unchanged.
    """
    pyfunc_model = mlflow.pyfunc.load_model(model_uri)

    metadata = getattr(pyfunc_model, "metadata", None)
    flavors = getattr(metadata, "flavors", None) or {}
    native_loaders = (
        ("sklearn", mlflow.sklearn.load_model),
        ("xgboost", mlflow.xgboost.load_model),
    )
    for flavor, loader in native_loaders:
        if flavor not in flavors:
            continue
        try:
            native_model = loader(model_uri)
        except Exception as e:
            logger.warning(f"Could not load native '{flavor}' flavor ({e}); using pyfunc model")
            break
        if hasattr(native_model, "predict_proba"):
            logger.info(f"Using native '{flavor}' flavor for {model_uri}")
            return native_model

    return pyfunc_model


def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from MLflow model registry.

    Args:
        model_name: Name of the registered model
        stage: Model stage (Production, Staging, None)

    Returns:
        Tuple of (model, model_info). ``model_info`` is built from the params and
        metrics of the run that produced the registered version.

    Raises:
        Exception: If model not found or MLflow unavailable
    """
    logger.info(f"Loading model from MLflow: {model_name} stage={stage}")

    try:
        # Load model from registry
        model_uri = f"models:/{model_name}/{stage}"
        model = _load_mlflow_model(model_uri)

        # Get model version details
        client = mlflow.tracking.MlflowClient()
        model_versions = client.get_latest_versions(model_name, stages=[stage])
        if not model_versions:
            raise ValueError(f"No model found for {model_name} at stage {stage}")

        model_version = model_versions[0]

        # Get run metadata
        run = client.get_run(model_version.run_id)
        params = run.data.params
        metrics = run.data.metrics

        # Extract model info (MLflow params are always stored as strings)
        model_info = {
            "model_name": model_name,
            "model_version": model_version.version,
            "model_type": params.get("model_type", "unknown"),
            # start_time is in milliseconds since the epoch
            "trained_at": datetime.fromtimestamp(run.info.start_time / 1000).isoformat(),
            "dataset_version": params.get("dataset_version", None),
            "feature_engineering_enabled": params.get("feature_engineering", "true") == "true",
            "labels": json.loads(params.get("labels", "[]")),
            "per_class_metrics": []
        }

        # Extract per-class metrics: each "precision_<label>" metric anchors one entry
        for key, value in metrics.items():
            if not key.startswith("precision_"):
                continue
            label = key.replace("precision_", "")
            model_info["per_class_metrics"].append({
                "label": label,
                "precision": value,
                "recall": metrics.get(f"recall_{label}", 0.0),
                "f1_score": metrics.get(f"f1_{label}", 0.0),
                "support": int(params.get(f"support_{label}", 0))
            })

        logger.info(f"Successfully loaded model: {model_name} v{model_version.version}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from MLflow: {e}")
        raise


def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
    """
    Load model from local file using joblib.

    The file may contain either a bare model object or a dict with a ``"model"``
    entry and an optional ``"metadata"`` dict.

    Args:
        model_path: Path to .pkl model file

    Returns:
        Tuple of (model, model_info)

    Raises:
        FileNotFoundError: If model file doesn't exist
        ValueError: If the file is a dict without a ``"model"`` entry
        Exception: If model loading fails
    """
    model_path = Path(model_path)
    logger.info(f"Loading model from local file: {model_path}")

    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    try:
        # SECURITY: joblib.load unpickles arbitrary objects; only load trusted files.
        model_data = joblib.load(model_path)

        # Extract model and metadata
        if isinstance(model_data, dict):
            model = model_data.get("model")
            metadata = model_data.get("metadata", {})
            if model is None:
                raise ValueError(f"Model file {model_path} does not contain a 'model' entry")
        else:
            model = model_data
            metadata = {}

        # Build model info
        model_info = {
            "model_name": model_path.stem,
            "model_version": metadata.get("version", "local"),
            "model_type": metadata.get("model_type", "unknown"),
            "trained_at": metadata.get("trained_at", None),
            "dataset_version": metadata.get("dataset_version", None),
            "feature_engineering_enabled": metadata.get("feature_engineering_enabled", True),
            "labels": metadata.get("labels", []),
            "per_class_metrics": metadata.get("per_class_metrics", [])
        }

        logger.info(f"Successfully loaded local model: {model_path.name}")
        return model, model_info

    except Exception as e:
        logger.error(f"Failed to load model from local file: {e}")
        raise


# --- Startup Event ---

@app.on_event("startup")
async def startup_event():
    """
    Load pipeline config and model on startup.

    Any failure is logged and the service keeps running without a model; the
    prediction endpoints then answer 503 and /health reports the state.
    """
    logger.info("Starting inference service...")

    # Load pipeline config
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Service will start without a model.")
        return

    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")
    except Exception as e:
        logger.error(f"Failed to load pipeline config: {e}")
        return

    # Load model based on config
    inference_config = state.pipeline_config.stages.inference

    if not inference_config.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    try:
        if inference_config.model_source == "mlflow":
            # Configure MLflow tracking URI
            mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
            state.model, state.model_info = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage
            )
        elif inference_config.model_source == "local":
            state.model, state.model_info = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        else:
            logger.error(f"Unknown model_source: {inference_config.model_source}")
            return

        logger.info("Model loaded successfully")
        logger.info(f"Model info: {state.model_info['model_name']} "
                    f"v{state.model_info['model_version']} "
                    f"({state.model_info['model_type']})")

    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        logger.warning("Service will start without a model. Use /health to check status.")


# --- Health Check ---

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns service status, model loaded status, and dependency health.
    Status is "healthy" when the model and all dependencies are up, "degraded"
    when only the model is up, and "unhealthy" when no model is loaded.
    """
    model_loaded = state.model is not None

    # TODO: Actually check MLflow and database connections. Until real probes
    # exist both dependencies are reported as "connected" unconditionally.
    mlflow_status = "connected"
    db_status = "connected"

    # Determine overall status
    if model_loaded and mlflow_status == "connected" and db_status == "connected":
        overall_status = "healthy"
    elif model_loaded:
        overall_status = "degraded"
    else:
        overall_status = "unhealthy"

    return HealthResponse(
        status=overall_status,
        model_loaded=model_loaded,
        mlflow=mlflow_status,
        database=db_status
    )


# --- Model Info Endpoints ---

@app.get("/model/info", response_model=ModelInfoResponse)
async def get_model_info():
    """
    Get detailed model information and per-class metrics.

    Returns HTTP 503 if no model is loaded.
    """
    if state.model is None or state.model_info is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Build response from loaded model info
    return ModelInfoResponse(
        model_name=state.model_info["model_name"],
        model_version=state.model_info.get("model_version"),
        model_type=state.model_info["model_type"],
        trained_at=state.model_info.get("trained_at"),
        dataset_version=state.model_info.get("dataset_version"),
        feature_engineering_enabled=state.model_info["feature_engineering_enabled"],
        labels=state.model_info.get("labels", []),
        per_class_metrics=[
            PerClassMetrics(**metric)
            for metric in state.model_info.get("per_class_metrics", [])
        ]
    )


@app.get("/model/labels", response_model=List[LabelInfo])
async def get_model_labels():
    """
    Get all pattern labels the model can predict with display colors.

    Returns HTTP 503 if no model is loaded.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # TODO: Load label colors from config or database
    # For now, every label gets the same placeholder grey.
    labels = state.model_info.get("labels", []) if state.model_info else []

    return [
        LabelInfo(name=label, color="#888888")
        for label in labels
    ]


# --- Prediction Helper Functions ---

def group_prediction_spans(
    predictions: List[PredictionResult]
) -> List[PredictionSpan]:
    """
    Group consecutive predictions with the same label into spans.

    Only groups non-"O" (non-background) predictions. "O" predictions are
    treated as background: they close any open span and never start one.

    Args:
        predictions: List of per-candle predictions, in time order

    Returns:
        List of prediction spans with running-average confidence
    """
    if not predictions:
        return []

    spans = []
    current_span = None

    for pred in predictions:
        # Skip "O" (background) predictions
        if pred.label == "O":
            if current_span is not None:
                # End current span
                spans.append(current_span)
                current_span = None
            continue

        # Start new span or continue current span
        if current_span is None or current_span.label != pred.label:
            # End previous span if exists
            if current_span is not None:
                spans.append(current_span)

            # Start new span
            current_span = PredictionSpan(
                start_time=pred.time,
                end_time=pred.time,
                label=pred.label,
                avg_confidence=pred.confidence,
                candle_count=1
            )
        else:
            # Continue current span
            current_span.end_time = pred.time
            current_span.candle_count += 1
            # Update running average confidence
            total_confidence = current_span.avg_confidence * (current_span.candle_count - 1)
            current_span.avg_confidence = (total_confidence + pred.confidence) / current_span.candle_count

    # Don't forget last span
    if current_span is not None:
        spans.append(current_span)

    logger.info(f"Grouped {len(predictions)} predictions into {len(spans)} spans")
    return spans


def _to_1d(values: Any) -> np.ndarray:
    """Coerce model output (ndarray, Series, or one-column DataFrame) to a 1-D array."""
    arr = np.asarray(values)
    # pyfunc models may return a single-column DataFrame; iterating that
    # directly would yield column names instead of predictions.
    if arr.ndim == 2 and arr.shape[1] == 1:
        arr = arr[:, 0]
    return arr


def _predict_candles(candles_data: List[Dict[str, Any]]) -> List[PredictionResult]:
    """
    Preprocess raw candle dicts, run the loaded model, and build per-candle results.

    Shared by /predict and /predict/batch so both endpoints score candles the
    same way. Confidence is the max class probability when the model exposes
    ``predict_proba``; otherwise it defaults to 1.0.

    Args:
        candles_data: Candle records (dicts with at least the CandleData fields)

    Returns:
        One PredictionResult per preprocessed row

    Raises:
        Whatever preprocess_candles or the model raise; callers map
        ValueError to HTTP 400 and other errors to HTTP 500.
    """
    # Preprocess candles (feature engineering)
    df_preprocessed = preprocess_candles(candles_data, state.pipeline_config)

    # Keep times for results mapping; assumes preprocess_candles returns a
    # 'time' column row-aligned with the features — TODO confirm.
    times = df_preprocessed['time'].values

    # Extract feature columns (exclude 'time')
    X = extract_feature_columns(df_preprocessed)

    # Get predictions and probabilities
    y_pred = _to_1d(state.model.predict(X))
    if hasattr(state.model, 'predict_proba'):
        # Confidence is the probability of the predicted (most likely) class
        confidences = np.max(state.model.predict_proba(X), axis=1)
    else:
        # Fallback for models without predict_proba
        confidences = np.ones(len(y_pred))

    # Get label names (handle both string and int predictions)
    if state.label_encoder is not None:
        # Model predicts integers, map to labels
        labels = [state.label_encoder.get(int(pred), f"unknown_{pred}") for pred in y_pred]
    else:
        # Model predicts strings directly
        labels = [str(pred) for pred in y_pred]

    return [
        PredictionResult(
            time=int(time),
            label=label,
            confidence=float(conf)
        )
        for time, label, conf in zip(times, labels, confidences)
    ]


def _build_model_info() -> ModelInfo:
    """Build the ModelInfo summary embedded in prediction responses."""
    return ModelInfo(
        model_name=state.model_info["model_name"],
        model_version=state.model_info.get("model_version"),
        model_type=state.model_info["model_type"],
        trained_at=state.model_info.get("trained_at"),
        dataset_version=state.model_info.get("dataset_version"),
        feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
    )


def _is_date_only(value: str) -> bool:
    """Return True if ``value`` is a bare ISO date (no time component)."""
    try:
        date.fromisoformat(value.strip())
    except ValueError:
        return False
    return True


def _parse_date_range(start_date: str, end_date: str) -> tuple[int, int]:
    """
    Convert ISO date/datetime strings into an inclusive Unix-seconds range.

    A date-only ``end_date`` (YYYY-MM-DD) covers that entire day, so candles
    later than midnight on the end date are included. Naive values are
    interpreted as UTC by ``pd.Timestamp.timestamp``.

    Raises:
        ValueError: If either date cannot be parsed or start is after end.
    """
    # pd.Timestamp("") yields NaT, whose .timestamp() raises ValueError.
    start_ts = int(pd.Timestamp(start_date).timestamp())
    end = pd.Timestamp(end_date)
    if _is_date_only(end_date):
        # Last whole second of the end day (candle times are integer seconds)
        end_ts = int((end + pd.Timedelta(days=1)).timestamp()) - 1
    else:
        end_ts = int(end.timestamp())
    if start_ts > end_ts:
        raise ValueError("start_date must not be after end_date")
    return start_ts, end_ts


# --- Prediction Endpoints ---

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict candlestick patterns for provided candles.

    Accepts OHLCV candles, runs preprocessing, and returns predictions.
    Responds 503 without a model, 500 without a pipeline config, 400 on
    validation errors raised during preprocessing/prediction.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    # Defensive: PredictRequest already enforces min_length=1.
    if not request.candles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Candles array cannot be empty"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(f"Predict request: {request.pair} {request.timeframe}, {len(request.candles)} candles")

    try:
        # Convert candles to list of dicts
        candles_data = [candle.model_dump() for candle in request.candles]

        predictions = _predict_candles(candles_data)
        spans = group_prediction_spans(predictions)
        model_info = _build_model_info()

        logger.info(
            f"Prediction complete: {len(predictions)} candles, "
            f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
        )

        return PredictResponse(
            predictions=predictions,
            spans=spans,
            model_info=model_info
        )

    except ValueError as e:
        logger.error(f"Prediction validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        )


@app.post("/predict/batch", response_model=PredictResponse)
async def predict_batch(request: BatchPredictRequest):
    """
    Batch prediction for a date range.

    Loads data from storage and returns predictions for the full range.
    A date-only end_date includes that whole day. Responds 400 on invalid
    dates, 404 when no data is available, 500 on processing failures.
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No model available"
        )

    if state.pipeline_config is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Pipeline configuration not loaded"
        )

    logger.info(
        f"Batch predict: {request.pair} {request.timeframe} "
        f"from {request.start_date} to {request.end_date}"
    )

    # Validate dates up front so malformed input is a client error, not a 500.
    try:
        start_ts, end_ts = _parse_date_range(request.start_date, request.end_date)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid date range: {e}"
        ) from e

    try:
        # Load OHLCV data from raw data path.
        # NOTE(review): raw_path is a single configured file; pair/timeframe
        # from the request are not used to select data — confirm intent.
        raw_path = Path(state.pipeline_config.data.raw_path)
        if not raw_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No data found for {request.pair} {request.timeframe}"
            )

        df_raw = pd.read_csv(raw_path)

        # Filter by date range if time column exists
        if 'time' in df_raw.columns:
            df_filtered = df_raw[
                (df_raw['time'] >= start_ts) & (df_raw['time'] <= end_ts)
            ].copy()

            if len(df_filtered) == 0:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"No data found in date range {request.start_date} to {request.end_date}"
                )

            logger.info(f"Loaded {len(df_filtered)} candles from {request.start_date} to {request.end_date}")
        else:
            df_filtered = df_raw
            logger.warning("No 'time' column found, using all data")

        # Get batch size from config; a missing or non-positive value would make
        # range() raise, so fall back to a single batch.
        batch_size = state.pipeline_config.stages.inference.batch_size
        if not batch_size or batch_size < 1:
            batch_size = max(len(df_filtered), 1)

        # Process in batches.
        # NOTE(review): each batch is preprocessed independently, so any
        # windowed features may lose context at batch boundaries — verify
        # against preprocess_candles.
        all_predictions: List[PredictionResult] = []
        for i in range(0, len(df_filtered), batch_size):
            batch_df = df_filtered.iloc[i:i + batch_size]
            logger.info(f"Processing batch {i // batch_size + 1}: rows {i} to {i + len(batch_df)}")
            all_predictions.extend(_predict_candles(batch_df.to_dict('records')))

        # Group all predictions into spans (across batch boundaries)
        all_spans = group_prediction_spans(all_predictions)
        model_info = _build_model_info()

        logger.info(
            f"Batch prediction complete: {len(all_predictions)} candles, "
            f"{len(all_spans)} spans"
        )

        return PredictResponse(
            predictions=all_predictions,
            spans=all_spans,
            model_info=model_info
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch prediction failed: {str(e)}"
        )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)