fix: implement real health checks in ML service /health endpoint
- Execute SELECT 1 against PostgreSQL via SQLAlchemy session to verify DB connectivity
- Send an HTTP GET (3-second timeout) to ${MLFLOW_TRACKING_URI}/health (default: http://localhost:5000/health) to verify MLflow
- Return HTTP 503 if any component is unhealthy, HTTP 200 only when both are healthy
- Report "healthy" / "unhealthy" (replacing "connected" / "disconnected") in the database and mlflow fields
- Add imports: `import requests as http_requests`, `text as sa_text` (from sqlalchemy), and `Response` (from fastapi)
Closes task 5.6 in openspec/changes/code-review-fix/tasks.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d75b05b585
commit
f94d16c6ab
2 changed files with 40 additions and 25 deletions
|
|
@@ -46,7 +46,7 @@
|
||||||
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
||||||
- [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
- [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
||||||
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
||||||
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
- [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
||||||
- [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
|
- [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
|
||||||
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
|
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -15,7 +15,9 @@ from typing import Optional, Dict, Any, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status
|
import requests as http_requests
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@@ -24,7 +26,7 @@ import joblib
|
||||||
import mlflow
|
import mlflow
|
||||||
import mlflow.sklearn
|
import mlflow.sklearn
|
||||||
import mlflow.xgboost
|
import mlflow.xgboost
|
||||||
from sqlalchemy import update as sa_update, desc
|
from sqlalchemy import update as sa_update, desc, text as sa_text
|
||||||
|
|
||||||
from app.config import load_config, PipelineConfig, get_default_config
|
from app.config import load_config, PipelineConfig, get_default_config
|
||||||
from app.db import get_db, TrainingRun, init_db
|
from app.db import get_db, TrainingRun, init_db
|
||||||
|
|
@@ -183,8 +185,8 @@ class HealthResponse(BaseModel):
|
||||||
"""Health check response."""
|
"""Health check response."""
|
||||||
status: str = Field(..., description="healthy, degraded, or unhealthy")
|
status: str = Field(..., description="healthy, degraded, or unhealthy")
|
||||||
model_loaded: bool
|
model_loaded: bool
|
||||||
mlflow: str = Field(..., description="connected or disconnected")
|
mlflow: str = Field(..., description="healthy or unhealthy")
|
||||||
database: str = Field(..., description="connected or disconnected")
|
database: str = Field(..., description="healthy or unhealthy")
|
||||||
|
|
||||||
|
|
||||||
# --- Model Loading Functions ---
|
# --- Model Loading Functions ---
|
||||||
|
|
@@ -473,43 +475,56 @@ async def startup_event():
|
||||||
# --- Health Check ---
|
# --- Health Check ---
|
||||||
|
|
||||||
@app.get("/health", response_model=HealthResponse)
|
@app.get("/health", response_model=HealthResponse)
|
||||||
async def health_check():
|
async def health_check(response: Response):
|
||||||
"""
|
"""
|
||||||
Health check endpoint.
|
Health check endpoint.
|
||||||
|
|
||||||
Returns service status, model loaded status, and dependency health.
|
Returns service status, model loaded status, and dependency health.
|
||||||
|
Returns HTTP 503 if any component (database or MLflow) is unhealthy.
|
||||||
"""
|
"""
|
||||||
model_loaded = state.model is not None
|
model_loaded = state.model is not None
|
||||||
|
|
||||||
# Check MLflow connection
|
# Check database connection with a real SELECT 1
|
||||||
mlflow_status = "disconnected"
|
db_status = "unhealthy"
|
||||||
try:
|
try:
|
||||||
# TODO: Actually check MLflow connection
|
with get_db() as db:
|
||||||
mlflow_status = "connected"
|
db.execute(sa_text("SELECT 1"))
|
||||||
except Exception:
|
db_status = "healthy"
|
||||||
pass
|
except Exception as exc:
|
||||||
|
logger.warning(f"Health check: database unhealthy: {exc}")
|
||||||
|
|
||||||
# Check database connection
|
# Check MLflow connection via HTTP GET to its health endpoint
|
||||||
db_status = "disconnected"
|
mlflow_status = "unhealthy"
|
||||||
try:
|
try:
|
||||||
# TODO: Actually check database connection
|
mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
|
||||||
db_status = "connected"
|
mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health"
|
||||||
except Exception:
|
resp = http_requests.get(mlflow_health_url, timeout=3)
|
||||||
pass
|
if resp.status_code == 200:
|
||||||
|
mlflow_status = "healthy"
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Health check: MLflow returned HTTP {resp.status_code}"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Health check: MLflow unreachable: {exc}")
|
||||||
|
|
||||||
# Determine overall status
|
# Determine overall status
|
||||||
if model_loaded and mlflow_status == "connected" and db_status == "connected":
|
if db_status == "healthy" and mlflow_status == "healthy":
|
||||||
overall_status = "healthy"
|
overall_status = "healthy"
|
||||||
elif model_loaded:
|
elif model_loaded:
|
||||||
overall_status = "degraded"
|
overall_status = "degraded"
|
||||||
else:
|
else:
|
||||||
overall_status = "unhealthy"
|
overall_status = "unhealthy"
|
||||||
|
|
||||||
|
# Return HTTP 503 if any dependency is unhealthy
|
||||||
|
if db_status != "healthy" or mlflow_status != "healthy":
|
||||||
|
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
|
||||||
return HealthResponse(
|
return HealthResponse(
|
||||||
status=overall_status,
|
status=overall_status,
|
||||||
model_loaded=model_loaded,
|
model_loaded=model_loaded,
|
||||||
mlflow=mlflow_status,
|
mlflow=mlflow_status,
|
||||||
database=db_status
|
database=db_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue