code-review-fix task 15.1: replace @app.on_event startup with FastAPI lifespan pattern in main.py
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
21c855db89
commit
41287b20c6
1 changed file with 88 additions and 83 deletions
|
|
@ -11,6 +11,7 @@ import os
|
|||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
|
@ -57,11 +58,97 @@ async def verify_api_key(x_api_key: str = Header(default="")):
|
|||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
# --- Lifespan ---
|
||||
|
||||
def _fail_stale_training_runs():
    """
    Mark every training run still recorded as "running" as failed.

    Such rows belong to a previous process and will never complete. This is
    best-effort: any exception is logged and swallowed so that a database
    problem here does not prevent the service from starting.
    """
    try:
        with get_db() as db:
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12.
            # It is kept because the timezone handling of the completed_at
            # column is not visible from here - confirm before switching to
            # an aware timestamp.
            stmt = (
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.utcnow(),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            result = db.execute(stmt)
            db.commit()
            if result.rowcount:
                logger.warning(
                    f"Marked {result.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as exc:
        logger.error(f"Failed to clean up stale training runs: {exc}")


def _load_startup_model():
    """
    Load the pipeline config and, when inference is enabled, the model.

    Results are stored on the module-level ``state`` object
    (``pipeline_config``, ``model``, ``model_info``). Every failure is logged
    and swallowed so that the service can still start without a model.
    """
    config_path = Path("config/pipeline.yaml")
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(config_path)
        logger.info(f"Loaded pipeline config from {config_path}")

        # Read the inference section inside the same try so that a malformed
        # config object is reported as a config failure instead of crashing
        # startup.
        inference_config = state.pipeline_config.stages.inference
        inference_enabled = inference_config.enabled
    except Exception as e:
        logger.error(f"Failed to load pipeline config: {e}")
        return

    if not inference_enabled:
        logger.warning("Inference stage is disabled in config")
        return

    try:
        if inference_config.model_source == "mlflow":
            # Configure MLflow tracking URI
            mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)

            state.model, state.model_info = load_model_from_mlflow(
                model_name=inference_config.mlflow_model_name,
                stage=inference_config.mlflow_model_stage
            )
        elif inference_config.model_source == "local":
            state.model, state.model_info = load_model_from_local(
                model_path=inference_config.local_model_path
            )
        else:
            logger.error(f"Unknown model_source: {inference_config.model_source}")
            # Stop here: nothing was loaded, so the success messages below
            # must not be emitted and state.model_info must not be read.
            return

        logger.info("Model loaded successfully")
        logger.info(f"Model info: {state.model_info['model_name']} "
                    f"v{state.model_info['model_version']} "
                    f"({state.model_info['model_type']})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        logger.warning("Service will start without a model. Use /health to check status.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager.

    Everything before ``yield`` runs once at startup: the database is
    initialised, stale "running" training runs are marked as failed, and the
    pipeline config and model are loaded. Code after ``yield`` runs at
    shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used by the body.
    """
    # --- Startup ---
    logger.info("Starting inference service...")

    # Ensure training_runs table exists. This call is deliberately not
    # guarded: an exception from init_db() propagates and aborts startup.
    init_db()

    _fail_stale_training_runs()
    _load_startup_model()

    yield

    # --- Shutdown (nothing to do currently) ---
|
||||
|
||||
|
||||
# FastAPI app
|
||||
# Application instance. Startup and shutdown logic is supplied through the
# `lifespan` context manager rather than @app.on_event handlers.
app = FastAPI(
    title="Candle Pattern Inference API",
    description="ML inference service for candlestick pattern recognition",
    version="1.0.0",
    lifespan=lifespan,
)
|
||||
|
||||
# Parse CORS origins from environment variable or use default
|
||||
|
|
@ -391,88 +478,6 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
|||
raise
|
||||
|
||||
|
||||
# --- Startup Event ---
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """
    Prepare the service when the application starts.

    Initialises the database, marks leftover "running" training runs as
    failed, then loads the pipeline config and, if inference is enabled, the
    model. Config and model problems are logged and end the handler early
    without raising.
    """
    logger.info("Starting inference service...")

    # Ensure training_runs table exists
    init_db()

    # Runs still flagged "running" were owned by an earlier process and can
    # never finish, so close them out as failed.
    try:
        with get_db() as session:
            outcome = session.execute(
                sa_update(TrainingRun)
                .where(TrainingRun.status == "running")
                .values(
                    status="failed",
                    completed_at=datetime.utcnow(),
                    metrics_summary={"error": "Service restarted while training was in progress"},
                )
            )
            session.commit()
            if outcome.rowcount:
                logger.warning(
                    f"Marked {outcome.rowcount} stale training run(s) as failed on startup"
                )
    except Exception as cleanup_error:
        logger.error(f"Failed to clean up stale training runs: {cleanup_error}")

    # Pipeline config: give up quietly when the file is missing or unreadable.
    pipeline_yaml = Path("config/pipeline.yaml")
    if not pipeline_yaml.exists():
        logger.warning(f"Config file not found: {pipeline_yaml}. Using defaults.")
        return

    try:
        state.pipeline_config = load_config(pipeline_yaml)
        logger.info(f"Loaded pipeline config from {pipeline_yaml}")
    except Exception as config_error:
        logger.error(f"Failed to load pipeline config: {config_error}")
        return

    # Model selection is driven by the inference stage of the config.
    inference_settings = state.pipeline_config.stages.inference
    if not inference_settings.enabled:
        logger.warning("Inference stage is disabled in config")
        return

    try:
        if inference_settings.model_source == "mlflow":
            # Point MLflow at the tracking server named in the training stage.
            tracking_uri = state.pipeline_config.stages.training.mlflow.tracking_uri
            mlflow.set_tracking_uri(tracking_uri)

            loaded = load_model_from_mlflow(
                model_name=inference_settings.mlflow_model_name,
                stage=inference_settings.mlflow_model_stage
            )
            state.model, state.model_info = loaded
        elif inference_settings.model_source == "local":
            loaded = load_model_from_local(
                model_path=inference_settings.local_model_path
            )
            state.model, state.model_info = loaded
        else:
            logger.error(f"Unknown model_source: {inference_settings.model_source}")
            return

        logger.info("Model loaded successfully")
        logger.info(f"Model info: {state.model_info['model_name']} "
                    f"v{state.model_info['model_version']} "
                    f"({state.model_info['model_type']})")
    except Exception as load_error:
        logger.error(f"Failed to load model: {load_error}")
        logger.warning("Service will start without a model. Use /health to check status.")
|
||||
|
||||
|
||||
# --- Health Check ---
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue