code-review-fix task 15.1: replace @app.on_event startup with FastAPI lifespan pattern in main.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 20:57:06 +01:00
parent 21c855db89
commit 41287b20c6

View file

@ -11,6 +11,7 @@ import os
import re
import threading
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, Dict, Any, List
from datetime import datetime
@ -57,11 +58,97 @@ async def verify_api_key(x_api_key: str = Header(default="")):
raise HTTPException(status_code=401, detail="Unauthorized")
# --- Lifespan ---
def _fail_stale_training_runs() -> None:
    """Mark training runs still recorded as "running" as failed.

    Such records belong to a previous process and will never complete.
    Any error is logged and swallowed so that startup can continue.
    """
    try:
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.utcnow(),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            result = db.execute(stmt)
            db.commit()
            if result.rowcount:
                logger.warning(
                    f"Marked {result.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as exc:
        logger.error(f"Failed to clean up stale training runs: {exc}")


def _load_startup_model(inference_config) -> None:
    """Load the model selected by ``inference_config`` into ``state``.

    Supports the "mlflow" and "local" model sources. An unknown source is
    logged and the function returns without reporting a successful load.
    Loading errors are logged and swallowed so the service can still start
    without a model.
    """
    try:
        if inference_config.model_source == "mlflow":
            # Configure MLflow tracking URI before resolving the model.
            mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
            state.model, state.model_info = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage,
            )
        elif inference_config.model_source == "local":
            state.model, state.model_info = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        else:
            logger.error(f"Unknown model_source: {inference_config.model_source}")
            # Nothing was loaded: stop here instead of falling through to the
            # success log and the model_info lookups below.
            return

        logger.info("Model loaded successfully")
        logger.info(f"Model info: {state.model_info['model_name']} "
                    f"v{state.model_info['model_version']} "
                    f"({state.model_info['model_type']})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        logger.warning("Service will start without a model. Use /health to check status.")


def _load_startup_config_and_model() -> None:
    """Load ``config/pipeline.yaml`` into ``state`` and then load the model.

    A missing config file, a disabled inference stage, or a config error is
    logged and ends the step early; none of these abort startup.
    """
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")

        # Load model based on config
        inference_config = state.pipeline_config.stages.inference
        if not inference_config.enabled:
            logger.warning("Inference stage is disabled in config")
            return

        _load_startup_model(inference_config)
    except Exception as e:
        logger.error(f"Failed to load pipeline config: {e}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager.

    Startup (before ``yield``): initialise the database, mark stale
    "running" training runs as failed, then load the pipeline config and
    the model. Shutdown (after ``yield``): nothing to do currently.

    Args:
        app: The FastAPI application this lifespan is attached to
            (not used by the handler itself).
    """
    # --- Startup ---
    logger.info("Starting inference service...")

    # Ensure training_runs table exists
    init_db()

    _fail_stale_training_runs()
    _load_startup_config_and_model()

    yield
    # --- Shutdown (nothing to do currently) ---
# FastAPI app
# Application instance; startup/shutdown logic is supplied via `lifespan`.
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0",
    lifespan=lifespan,
)
# Parse CORS origins from environment variable or use default
@ -391,88 +478,6 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
raise
# --- Startup Event ---
@app.on_event("startup")
async def startup_event():
    """
    Prepare the service at startup: database, stale runs, config, model.

    Every failure is logged rather than raised, so the service can come up
    without a config or without a model.
    """
    # NOTE(review): the app is also constructed with a `lifespan` handler that
    # performs the same steps — confirm whether this handler is still needed.
    logger.info("Starting inference service...")

    # Make sure the training_runs table is present.
    init_db()

    # Runs still flagged "running" were owned by an earlier process and can
    # never finish, so flag them as failed.
    try:
        with get_db() as session:
            mark_failed = (
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.utcnow(),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            outcome = session.execute(mark_failed)
            session.commit()
            if outcome.rowcount:
                logger.warning(
                    f"Marked {outcome.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as cleanup_err:
        logger.error(f"Failed to clean up stale training runs: {cleanup_err}")

    # Pipeline configuration.
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")
    except Exception as config_err:
        logger.error(f"Failed to load pipeline config: {config_err}")
        return

    # The inference section of the config decides whether and how to load.
    inference_config = state.pipeline_config.stages.inference
    if not inference_config.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    # Model loading.
    try:
        if inference_config.model_source == "mlflow":
            # The tracking URI must be configured before the model lookup.
            mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
            state.model, state.model_info = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage
            )
        elif inference_config.model_source == "local":
            state.model, state.model_info = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        else:
            logger.error(f"Unknown model_source: {inference_config.model_source}")
            return

        logger.info("Model loaded successfully")
        logger.info(f"Model info: {state.model_info['model_name']} "
                    f"v{state.model_info['model_version']} "
                    f"({state.model_info['model_type']})")
    except Exception as load_err:
        logger.error(f"Failed to load model: {load_err}")
        logger.warning("Service will start without a model. Use /health to check status.")
# --- Health Check ---
@app.get("/health", response_model=HealthResponse)