code-review-fix task 15.1: replace @app.on_event startup with FastAPI lifespan pattern in main.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 20:57:06 +01:00
parent 21c855db89
commit 41287b20c6

View file

@ -11,6 +11,7 @@ import os
import re import re
import threading import threading
import uuid import uuid
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from datetime import datetime from datetime import datetime
@ -57,11 +58,97 @@ async def verify_api_key(x_api_key: str = Header(default="")):
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
# --- Lifespan ---
def _fail_stale_training_runs() -> None:
    """
    Mark every training run still recorded as "running" as failed.

    Such records belong to a previous process and will never complete.
    This is best-effort: any error is logged and swallowed so that a
    database problem does not prevent the service from starting.
    """
    try:
        with get_db() as db:
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.utcnow(),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            result = db.execute(stmt)
            db.commit()
            if result.rowcount:
                logger.warning(
                    f"Marked {result.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as exc:
        logger.error(f"Failed to clean up stale training runs: {exc}")


def _load_pipeline_and_model() -> None:
    """
    Load config/pipeline.yaml into ``state`` and, if enabled, the model.

    Every failure path is logged and the function returns normally, so the
    service can still start without a config or without a model.
    """
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")

        # Load model based on config
        inference_config = state.pipeline_config.stages.inference
        if not inference_config.enabled:
            logger.warning("Inference stage is disabled in config")
            return

        # Load model
        try:
            if inference_config.model_source == "mlflow":
                # Configure MLflow tracking URI
                mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
                state.model, state.model_info = load_model_from_mlflow(
                    model_name=inference_config.mlflow_model_name,
                    stage=inference_config.mlflow_model_stage
                )
            elif inference_config.model_source == "local":
                state.model, state.model_info = load_model_from_local(
                    model_path=inference_config.local_model_path
                )
            else:
                logger.error(f"Unknown model_source: {inference_config.model_source}")
                # Nothing was loaded: stop here instead of falling through to
                # the success logging below, which would report a model that
                # does not exist and then fail while reading model_info.
                return

            logger.info("Model loaded successfully")
            logger.info(f"Model info: {state.model_info['model_name']} "
                        f"v{state.model_info['model_version']} "
                        f"({state.model_info['model_type']})")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.warning("Service will start without a model. Use /health to check status.")
    except Exception as e:
        logger.error(f"Failed to load pipeline config: {e}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager.

    Everything before ``yield`` runs once at startup; everything after it
    runs at shutdown (currently nothing).

    The startup steps live in private helpers because this generator must
    reach its ``yield``: an early ``return`` here would break the context
    manager, whereas the helpers can use guard-clause returns freely.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # --- Startup ---
    logger.info("Starting inference service...")

    # Ensure training_runs table exists
    init_db()

    _fail_stale_training_runs()
    _load_pipeline_and_model()

    yield

    # --- Shutdown (nothing to do currently) ---
# FastAPI app # FastAPI app
app = FastAPI( app = FastAPI(
title="Candle Pattern Inference API", title="Candle Pattern Inference API",
description="ML inference service for candlestick pattern recognition", description="ML inference service for candlestick pattern recognition",
version="1.0.0" version="1.0.0",
lifespan=lifespan
) )
# Parse CORS origins from environment variable or use default # Parse CORS origins from environment variable or use default
@ -391,88 +478,6 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
raise raise
# --- Startup Event ---
@app.on_event("startup")
async def startup_event():
    """
    Prepare the service at startup.

    Steps, in order: initialise the database, mark leftover "running"
    training runs as failed, read the pipeline config and, when inference
    is enabled, load the model from the configured source. Config and
    model problems are logged and leave the service running without them.
    """
    logger.info("Starting inference service...")

    # Ensure training_runs table exists
    init_db()

    # Records still flagged "running" were left behind by an earlier process
    # and can never finish, so close them out as failed.
    try:
        with get_db() as session:
            failed_fields = {
                "status": "failed",
                "completed_at": datetime.utcnow(),
                "metrics_summary": {"error": "Service restarted while training was in progress"},
            }
            mark_failed = sa_update(TrainingRun)
            mark_failed = mark_failed.where(TrainingRun.status == "running")
            mark_failed = mark_failed.values(**failed_fields)

            outcome = session.execute(mark_failed)
            session.commit()

            stale_count = outcome.rowcount
            if stale_count:
                logger.warning(
                    f"Marked {stale_count} stale training run(s) as failed on startup"
                )
    except Exception as err:
        logger.error(f"Failed to clean up stale training runs: {err}")

    # Pipeline configuration
    pipeline_yaml = Path("config/pipeline.yaml")
    if not pipeline_yaml.exists():
        logger.warning(f"Config file not found: {pipeline_yaml}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(pipeline_yaml)
        logger.info(f"Loaded pipeline config from {pipeline_yaml}")
    except Exception as err:
        logger.error(f"Failed to load pipeline config: {err}")
        return

    # The inference section of the config decides whether and how to load a model
    inference_cfg = state.pipeline_config.stages.inference
    if not inference_cfg.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    try:
        source = inference_cfg.model_source
        if source not in ("mlflow", "local"):
            logger.error(f"Unknown model_source: {source}")
            return

        if source == "mlflow":
            # MLflow needs its tracking URI set before the model is fetched
            tracking_uri = state.pipeline_config.stages.training.mlflow.tracking_uri
            mlflow.set_tracking_uri(tracking_uri)
            loaded = load_model_from_mlflow(
                model_name=inference_cfg.mlflow_model_name,
                stage=inference_cfg.mlflow_model_stage,
            )
        else:
            loaded = load_model_from_local(model_path=inference_cfg.local_model_path)
        state.model, state.model_info = loaded

        logger.info("Model loaded successfully")
        info = state.model_info
        logger.info(f"Model info: {info['model_name']} "
                    f"v{info['model_version']} "
                    f"({info['model_type']})")
    except Exception as err:
        logger.error(f"Failed to load model: {err}")
        logger.warning("Service will start without a model. Use /health to check status.")
# --- Health Check --- # --- Health Check ---
@app.get("/health", response_model=HealthResponse) @app.get("/health", response_model=HealthResponse)