fix: implement real health checks in ML service /health endpoint

- Execute SELECT 1 against PostgreSQL via SQLAlchemy session to verify DB connectivity
- HTTP GET to ${MLFLOW_TRACKING_URI}/health (default: http://localhost:5000/health) to verify MLflow
- Return HTTP 503 if any component is unhealthy, HTTP 200 only when both are healthy
- Values: "healthy" / "unhealthy" for database and mlflow fields
- Add `import requests as http_requests` and `sa_text` imports

Closes task 5.6 in openspec/changes/code-review-fix/tasks.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:30:12 +01:00
parent d75b05b585
commit f94d16c6ab
2 changed files with 40 additions and 25 deletions

View file

@ -46,7 +46,7 @@
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
- [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
- [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
- [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints

View file

@ -15,7 +15,9 @@ from typing import Optional, Dict, Any, List
from datetime import datetime
import json
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status
import requests as http_requests
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import numpy as np
@ -24,7 +26,7 @@ import joblib
import mlflow
import mlflow.sklearn
import mlflow.xgboost
from sqlalchemy import update as sa_update, desc
from sqlalchemy import update as sa_update, desc, text as sa_text
from app.config import load_config, PipelineConfig, get_default_config
from app.db import get_db, TrainingRun, init_db
@ -183,8 +185,8 @@ class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="healthy, degraded, or unhealthy")
model_loaded: bool
mlflow: str = Field(..., description="connected or disconnected")
database: str = Field(..., description="connected or disconnected")
mlflow: str = Field(..., description="healthy or unhealthy")
database: str = Field(..., description="healthy or unhealthy")
# --- Model Loading Functions ---
@ -473,43 +475,56 @@ async def startup_event():
# --- Health Check ---
@app.get("/health", response_model=HealthResponse)
async def health_check():
async def health_check(response: Response):
"""
Health check endpoint.
Returns service status, model loaded status, and dependency health.
Returns HTTP 503 if any component (database or MLflow) is unhealthy.
"""
model_loaded = state.model is not None
# Check MLflow connection
mlflow_status = "disconnected"
# Check database connection with a real SELECT 1
db_status = "unhealthy"
try:
# TODO: Actually check MLflow connection
mlflow_status = "connected"
except Exception:
pass
# Check database connection
db_status = "disconnected"
with get_db() as db:
db.execute(sa_text("SELECT 1"))
db_status = "healthy"
except Exception as exc:
logger.warning(f"Health check: database unhealthy: {exc}")
# Check MLflow connection via HTTP GET to its health endpoint
mlflow_status = "unhealthy"
try:
# TODO: Actually check database connection
db_status = "connected"
except Exception:
pass
mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health"
resp = http_requests.get(mlflow_health_url, timeout=3)
if resp.status_code == 200:
mlflow_status = "healthy"
else:
logger.warning(
f"Health check: MLflow returned HTTP {resp.status_code}"
)
except Exception as exc:
logger.warning(f"Health check: MLflow unreachable: {exc}")
# Determine overall status
if model_loaded and mlflow_status == "connected" and db_status == "connected":
if db_status == "healthy" and mlflow_status == "healthy":
overall_status = "healthy"
elif model_loaded:
overall_status = "degraded"
else:
overall_status = "unhealthy"
# Return HTTP 503 if any dependency is unhealthy
if db_status != "healthy" or mlflow_status != "healthy":
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
return HealthResponse(
status=overall_status,
model_loaded=model_loaded,
mlflow=mlflow_status,
database=db_status
database=db_status,
)