From 41287b20c6aa52dae112dba4f3854733bfb3f04f Mon Sep 17 00:00:00 2001 From: Marko Djordjevic Date: Wed, 18 Feb 2026 20:57:06 +0100 Subject: [PATCH] code-review-fix task 15.1: replace @app.on_event startup with FastAPI lifespan pattern in main.py Co-Authored-By: Claude Sonnet 4.6 --- services/ml/app/main.py | 171 +++++++++++++++++++++------------------- 1 file changed, 88 insertions(+), 83 deletions(-) diff --git a/services/ml/app/main.py b/services/ml/app/main.py index 588fbab..83fc328 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -11,6 +11,7 @@ import os import re import threading import uuid +from contextlib import asynccontextmanager from pathlib import Path from typing import Optional, Dict, Any, List from datetime import datetime @@ -57,11 +58,97 @@ async def verify_api_key(x_api_key: str = Header(default="")): raise HTTPException(status_code=401, detail="Unauthorized") +# --- Lifespan --- + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + FastAPI lifespan context manager. + Handles startup and shutdown logic. + """ + # --- Startup --- + logger.info("Starting inference service...") + + # Ensure training_runs table exists + init_db() + + # Mark any stale "running" records as failed — they belong to a previous + # process and will never complete. + try: + with get_db() as db: + stmt = ( + sa_update(TrainingRun) + .where(TrainingRun.status == "running") + .values( + status="failed", + completed_at=datetime.utcnow(), + metrics_summary={"error": "Service restarted while training was in progress"}, + ) + ) + result = db.execute(stmt) + db.commit() + if result.rowcount: + logger.warning( + f"Marked {result.rowcount} stale training run(s) as failed on startup" + ) + except Exception as exc: + logger.error(f"Failed to clean up stale training runs: {exc}") + + # Load pipeline config + config_path = Path("config/pipeline.yaml") + if config_path.exists(): + try: + state.pipeline_config = load_config(config_path) + logger.info(f"Loaded pipeline config from {config_path}") + + # Load model based on config + inference_config = state.pipeline_config.stages.inference + + if not inference_config.enabled: + logger.warning("Inference stage is disabled in config") + else: + # Load model + try: + if inference_config.model_source == "mlflow": + # Configure MLflow tracking URI + mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri) + + state.model, state.model_info = load_model_from_mlflow( + model_name=inference_config.mlflow_model_name, + stage=inference_config.mlflow_model_stage + ) + elif inference_config.model_source == "local": + state.model, state.model_info = load_model_from_local( + model_path=inference_config.local_model_path + ) + else: + logger.error(f"Unknown model_source: {inference_config.model_source}") + + logger.info("Model loaded successfully") + logger.info(f"Model info: {state.model_info['model_name']} " + f"v{state.model_info['model_version']} " + f"({state.model_info['model_type']})") + + except Exception as e: + logger.error(f"Failed to load model: {e}") + logger.warning("Service will start without a model. Use /health to check status.") + + except Exception as e: + logger.error(f"Failed to load pipeline config: {e}") + else: + logger.warning(f"Config file not found: {config_path}. Using defaults.") + + yield + + # --- Shutdown (nothing to do currently) --- + + # FastAPI app app = FastAPI( title="Candle Pattern Inference API", description="ML inference service for candlestick pattern recognition", - version="1.0.0" + version="1.0.0", + lifespan=lifespan ) # Parse CORS origins from environment variable or use default @@ -391,88 +478,6 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]: raise -# --- Startup Event --- - -@app.on_event("startup") -async def startup_event(): - """ - Load model and pipeline config on startup. - """ - logger.info("Starting inference service...") - - # Ensure training_runs table exists - init_db() - - # Mark any stale "running" records as failed — they belong to a previous - # process and will never complete. - try: - with get_db() as db: - stmt = ( - sa_update(TrainingRun) - .where(TrainingRun.status == "running") - .values( - status="failed", - completed_at=datetime.utcnow(), - metrics_summary={"error": "Service restarted while training was in progress"}, - ) - ) - result = db.execute(stmt) - db.commit() - if result.rowcount: - logger.warning( - f"Marked {result.rowcount} stale training run(s) as failed on startup" - ) - except Exception as exc: - logger.error(f"Failed to clean up stale training runs: {exc}") - - # Load pipeline config - config_path = Path("config/pipeline.yaml") - if not config_path.exists(): - logger.warning(f"Config file not found: {config_path}. Using defaults.") - return - - try: - state.pipeline_config = load_config(config_path) - logger.info(f"Loaded pipeline config from {config_path}") - except Exception as e: - logger.error(f"Failed to load pipeline config: {e}") - return - - # Load model based on config - inference_config = state.pipeline_config.stages.inference - - if not inference_config.enabled: - logger.warning("Inference stage is disabled in config") - return - - # Load model - try: - if inference_config.model_source == "mlflow": - # Configure MLflow tracking URI - mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri) - - state.model, state.model_info = load_model_from_mlflow( - model_name=inference_config.mlflow_model_name, - stage=inference_config.mlflow_model_stage - ) - elif inference_config.model_source == "local": - state.model, state.model_info = load_model_from_local( - model_path=inference_config.local_model_path - ) - else: - logger.error(f"Unknown model_source: {inference_config.model_source}") - return - - logger.info("Model loaded successfully") - logger.info(f"Model info: {state.model_info['model_name']} " - f"v{state.model_info['model_version']} " - f"({state.model_info['model_type']})") - - except Exception as e: - logger.error(f"Failed to load model: {e}") - logger.warning("Service will start without a model. Use /health to check status.") - - # --- Health Check --- @app.get("/health", response_model=HealthResponse)