fix: implement real health checks in ML service /health endpoint

- Execute SELECT 1 against PostgreSQL via SQLAlchemy session to verify DB connectivity
- HTTP GET to ${MLFLOW_TRACKING_URI}/health (default: http://localhost:5000/health) to verify MLflow
- Return HTTP 503 if any component is unhealthy, HTTP 200 only when both are healthy
- Values: "healthy" / "unhealthy" for database and mlflow fields
- Add `import requests as http_requests` and `sa_text` imports

Closes task 5.6 in openspec/changes/code-review-fix/tasks.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:30:12 +01:00
parent d75b05b585
commit f94d16c6ab
2 changed files with 40 additions and 25 deletions

View file

@ -46,7 +46,7 @@
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access - [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
- [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py` - [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py` - [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409` - [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
- [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030` - [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints - [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints

View file

@ -15,7 +15,9 @@ from typing import Optional, Dict, Any, List
from datetime import datetime from datetime import datetime
import json import json
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status import requests as http_requests
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import numpy as np import numpy as np
@ -24,7 +26,7 @@ import joblib
import mlflow import mlflow
import mlflow.sklearn import mlflow.sklearn
import mlflow.xgboost import mlflow.xgboost
from sqlalchemy import update as sa_update, desc from sqlalchemy import update as sa_update, desc, text as sa_text
from app.config import load_config, PipelineConfig, get_default_config from app.config import load_config, PipelineConfig, get_default_config
from app.db import get_db, TrainingRun, init_db from app.db import get_db, TrainingRun, init_db
@ -183,8 +185,8 @@ class HealthResponse(BaseModel):
"""Health check response.""" """Health check response."""
status: str = Field(..., description="healthy, degraded, or unhealthy") status: str = Field(..., description="healthy, degraded, or unhealthy")
model_loaded: bool model_loaded: bool
mlflow: str = Field(..., description="connected or disconnected") mlflow: str = Field(..., description="healthy or unhealthy")
database: str = Field(..., description="connected or disconnected") database: str = Field(..., description="healthy or unhealthy")
# --- Model Loading Functions --- # --- Model Loading Functions ---
@ -473,43 +475,56 @@ async def startup_event():
# --- Health Check --- # --- Health Check ---
@app.get("/health", response_model=HealthResponse) @app.get("/health", response_model=HealthResponse)
async def health_check(): async def health_check(response: Response):
""" """
Health check endpoint. Health check endpoint.
Returns service status, model loaded status, and dependency health. Returns service status, model loaded status, and dependency health.
Returns HTTP 503 if any component (database or MLflow) is unhealthy.
""" """
model_loaded = state.model is not None model_loaded = state.model is not None
# Check MLflow connection # Check database connection with a real SELECT 1
mlflow_status = "disconnected" db_status = "unhealthy"
try: try:
# TODO: Actually check MLflow connection with get_db() as db:
mlflow_status = "connected" db.execute(sa_text("SELECT 1"))
except Exception: db_status = "healthy"
pass except Exception as exc:
logger.warning(f"Health check: database unhealthy: {exc}")
# Check database connection
db_status = "disconnected" # Check MLflow connection via HTTP GET to its health endpoint
mlflow_status = "unhealthy"
try: try:
# TODO: Actually check database connection mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
db_status = "connected" mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health"
except Exception: resp = http_requests.get(mlflow_health_url, timeout=3)
pass if resp.status_code == 200:
mlflow_status = "healthy"
else:
logger.warning(
f"Health check: MLflow returned HTTP {resp.status_code}"
)
except Exception as exc:
logger.warning(f"Health check: MLflow unreachable: {exc}")
# Determine overall status # Determine overall status
if model_loaded and mlflow_status == "connected" and db_status == "connected": if db_status == "healthy" and mlflow_status == "healthy":
overall_status = "healthy" overall_status = "healthy"
elif model_loaded: elif model_loaded:
overall_status = "degraded" overall_status = "degraded"
else: else:
overall_status = "unhealthy" overall_status = "unhealthy"
# Return HTTP 503 if any dependency is unhealthy
if db_status != "healthy" or mlflow_status != "healthy":
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
return HealthResponse( return HealthResponse(
status=overall_status, status=overall_status,
model_loaded=model_loaded, model_loaded=model_loaded,
mlflow=mlflow_status, mlflow=mlflow_status,
database=db_status database=db_status,
) )