fix: implement real health checks in ML service /health endpoint
- Execute SELECT 1 against PostgreSQL via SQLAlchemy session to verify DB connectivity
- HTTP GET to ${MLFLOW_TRACKING_URI}/health (default: http://localhost:5000/health) to verify MLflow
- Return HTTP 503 if any component is unhealthy, HTTP 200 only when both are healthy
- Values: "healthy" / "unhealthy" for database and mlflow fields
- Add `import requests as http_requests` and `sa_text` imports
Closes task 5.6 in openspec/changes/code-review-fix/tasks.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d75b05b585
commit
f94d16c6ab
2 changed files with 40 additions and 25 deletions
|
|
@ -46,7 +46,7 @@
|
|||
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
||||
- [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
||||
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
||||
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
||||
- [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
||||
- [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
|
||||
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ from typing import Optional, Dict, Any, List
|
|||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status
|
||||
import requests as http_requests
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
|
@ -24,7 +26,7 @@ import joblib
|
|||
import mlflow
|
||||
import mlflow.sklearn
|
||||
import mlflow.xgboost
|
||||
from sqlalchemy import update as sa_update, desc
|
||||
from sqlalchemy import update as sa_update, desc, text as sa_text
|
||||
|
||||
from app.config import load_config, PipelineConfig, get_default_config
|
||||
from app.db import get_db, TrainingRun, init_db
|
||||
|
|
@ -183,8 +185,8 @@ class HealthResponse(BaseModel):
|
|||
"""Health check response."""
|
||||
status: str = Field(..., description="healthy, degraded, or unhealthy")
|
||||
model_loaded: bool
|
||||
mlflow: str = Field(..., description="connected or disconnected")
|
||||
database: str = Field(..., description="connected or disconnected")
|
||||
mlflow: str = Field(..., description="healthy or unhealthy")
|
||||
database: str = Field(..., description="healthy or unhealthy")
|
||||
|
||||
|
||||
# --- Model Loading Functions ---
|
||||
|
|
@ -473,43 +475,56 @@ async def startup_event():
|
|||
# --- Health Check ---
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
async def health_check(response: Response):
|
||||
"""
|
||||
Health check endpoint.
|
||||
|
||||
|
||||
Returns service status, model loaded status, and dependency health.
|
||||
Returns HTTP 503 if any component (database or MLflow) is unhealthy.
|
||||
"""
|
||||
model_loaded = state.model is not None
|
||||
|
||||
# Check MLflow connection
|
||||
mlflow_status = "disconnected"
|
||||
|
||||
# Check database connection with a real SELECT 1
|
||||
db_status = "unhealthy"
|
||||
try:
|
||||
# TODO: Actually check MLflow connection
|
||||
mlflow_status = "connected"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check database connection
|
||||
db_status = "disconnected"
|
||||
with get_db() as db:
|
||||
db.execute(sa_text("SELECT 1"))
|
||||
db_status = "healthy"
|
||||
except Exception as exc:
|
||||
logger.warning(f"Health check: database unhealthy: {exc}")
|
||||
|
||||
# Check MLflow connection via HTTP GET to its health endpoint
|
||||
mlflow_status = "unhealthy"
|
||||
try:
|
||||
# TODO: Actually check database connection
|
||||
db_status = "connected"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
|
||||
mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health"
|
||||
resp = http_requests.get(mlflow_health_url, timeout=3)
|
||||
if resp.status_code == 200:
|
||||
mlflow_status = "healthy"
|
||||
else:
|
||||
logger.warning(
|
||||
f"Health check: MLflow returned HTTP {resp.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Health check: MLflow unreachable: {exc}")
|
||||
|
||||
# Determine overall status
|
||||
if model_loaded and mlflow_status == "connected" and db_status == "connected":
|
||||
if db_status == "healthy" and mlflow_status == "healthy":
|
||||
overall_status = "healthy"
|
||||
elif model_loaded:
|
||||
overall_status = "degraded"
|
||||
else:
|
||||
overall_status = "unhealthy"
|
||||
|
||||
|
||||
# Return HTTP 503 if any dependency is unhealthy
|
||||
if db_status != "healthy" or mlflow_status != "healthy":
|
||||
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
|
||||
return HealthResponse(
|
||||
status=overall_status,
|
||||
model_loaded=model_loaded,
|
||||
mlflow=mlflow_status,
|
||||
database=db_status
|
||||
database=db_status,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue