code-review-fix task 15.1: replace @app.on_event startup with FastAPI lifespan pattern in main.py
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
21c855db89
commit
41287b20c6
1 changed files with 88 additions and 83 deletions
|
|
@ -11,6 +11,7 @@ import os
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -57,11 +58,97 @@ async def verify_api_key(x_api_key: str = Header(default="")):
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Lifespan ---
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
FastAPI lifespan context manager.
|
||||||
|
Handles startup and shutdown logic.
|
||||||
|
"""
|
||||||
|
# --- Startup ---
|
||||||
|
logger.info("Starting inference service...")
|
||||||
|
|
||||||
|
# Ensure training_runs table exists
|
||||||
|
init_db()
|
||||||
|
|
||||||
|
# Mark any stale "running" records as failed — they belong to a previous
|
||||||
|
# process and will never complete.
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
stmt = (
|
||||||
|
sa_update(TrainingRun)
|
||||||
|
.where(TrainingRun.status == "running")
|
||||||
|
.values(
|
||||||
|
status="failed",
|
||||||
|
completed_at=datetime.utcnow(),
|
||||||
|
metrics_summary={"error": "Service restarted while training was in progress"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = db.execute(stmt)
|
||||||
|
db.commit()
|
||||||
|
if result.rowcount:
|
||||||
|
logger.warning(
|
||||||
|
f"Marked {result.rowcount} stale training run(s) as failed on startup"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to clean up stale training runs: {exc}")
|
||||||
|
|
||||||
|
# Load pipeline config
|
||||||
|
config_path = Path("config/pipeline.yaml")
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
state.pipeline_config = load_config(config_path)
|
||||||
|
logger.info(f"Loaded pipeline config from {config_path}")
|
||||||
|
|
||||||
|
# Load model based on config
|
||||||
|
inference_config = state.pipeline_config.stages.inference
|
||||||
|
|
||||||
|
if not inference_config.enabled:
|
||||||
|
logger.warning("Inference stage is disabled in config")
|
||||||
|
else:
|
||||||
|
# Load model
|
||||||
|
try:
|
||||||
|
if inference_config.model_source == "mlflow":
|
||||||
|
# Configure MLflow tracking URI
|
||||||
|
mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
|
||||||
|
|
||||||
|
state.model, state.model_info = load_model_from_mlflow(
|
||||||
|
model_name=inference_config.mlflow_model_name,
|
||||||
|
stage=inference_config.mlflow_model_stage
|
||||||
|
)
|
||||||
|
elif inference_config.model_source == "local":
|
||||||
|
state.model, state.model_info = load_model_from_local(
|
||||||
|
model_path=inference_config.local_model_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
||||||
|
|
||||||
|
logger.info("Model loaded successfully")
|
||||||
|
logger.info(f"Model info: {state.model_info['model_name']} "
|
||||||
|
f"v{state.model_info['model_version']} "
|
||||||
|
f"({state.model_info['model_type']})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
logger.warning("Service will start without a model. Use /health to check status.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load pipeline config: {e}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Config file not found: {config_path}. Using defaults.")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# --- Shutdown (nothing to do currently) ---
|
||||||
|
|
||||||
|
|
||||||
# FastAPI app
|
# FastAPI app
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Candle Pattern Inference API",
|
title="Candle Pattern Inference API",
|
||||||
description="ML inference service for candlestick pattern recognition",
|
description="ML inference service for candlestick pattern recognition",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse CORS origins from environment variable or use default
|
# Parse CORS origins from environment variable or use default
|
||||||
|
|
@ -391,88 +478,6 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# --- Startup Event ---
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event():
|
|
||||||
"""
|
|
||||||
Load model and pipeline config on startup.
|
|
||||||
"""
|
|
||||||
logger.info("Starting inference service...")
|
|
||||||
|
|
||||||
# Ensure training_runs table exists
|
|
||||||
init_db()
|
|
||||||
|
|
||||||
# Mark any stale "running" records as failed — they belong to a previous
|
|
||||||
# process and will never complete.
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
stmt = (
|
|
||||||
sa_update(TrainingRun)
|
|
||||||
.where(TrainingRun.status == "running")
|
|
||||||
.values(
|
|
||||||
status="failed",
|
|
||||||
completed_at=datetime.utcnow(),
|
|
||||||
metrics_summary={"error": "Service restarted while training was in progress"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = db.execute(stmt)
|
|
||||||
db.commit()
|
|
||||||
if result.rowcount:
|
|
||||||
logger.warning(
|
|
||||||
f"Marked {result.rowcount} stale training run(s) as failed on startup"
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(f"Failed to clean up stale training runs: {exc}")
|
|
||||||
|
|
||||||
# Load pipeline config
|
|
||||||
config_path = Path("config/pipeline.yaml")
|
|
||||||
if not config_path.exists():
|
|
||||||
logger.warning(f"Config file not found: {config_path}. Using defaults.")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
state.pipeline_config = load_config(config_path)
|
|
||||||
logger.info(f"Loaded pipeline config from {config_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load pipeline config: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Load model based on config
|
|
||||||
inference_config = state.pipeline_config.stages.inference
|
|
||||||
|
|
||||||
if not inference_config.enabled:
|
|
||||||
logger.warning("Inference stage is disabled in config")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Load model
|
|
||||||
try:
|
|
||||||
if inference_config.model_source == "mlflow":
|
|
||||||
# Configure MLflow tracking URI
|
|
||||||
mlflow.set_tracking_uri(state.pipeline_config.stages.training.mlflow.tracking_uri)
|
|
||||||
|
|
||||||
state.model, state.model_info = load_model_from_mlflow(
|
|
||||||
model_name=inference_config.mlflow_model_name,
|
|
||||||
stage=inference_config.mlflow_model_stage
|
|
||||||
)
|
|
||||||
elif inference_config.model_source == "local":
|
|
||||||
state.model, state.model_info = load_model_from_local(
|
|
||||||
model_path=inference_config.local_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(f"Unknown model_source: {inference_config.model_source}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Model loaded successfully")
|
|
||||||
logger.info(f"Model info: {state.model_info['model_name']} "
|
|
||||||
f"v{state.model_info['model_version']} "
|
|
||||||
f"({state.model_info['model_type']})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model: {e}")
|
|
||||||
logger.warning("Service will start without a model. Use /health to check status.")
|
|
||||||
|
|
||||||
|
|
||||||
# --- Health Check ---
|
# --- Health Check ---
|
||||||
|
|
||||||
@app.get("/health", response_model=HealthResponse)
|
@app.get("/health", response_model=HealthResponse)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue