From f94d16c6abcff3ea162722aab470bbef04fc672a Mon Sep 17 00:00:00 2001 From: Marko Djordjevic Date: Wed, 18 Feb 2026 11:30:12 +0100 Subject: [PATCH] fix: implement real health checks in ML service /health endpoint - Execute SELECT 1 against PostgreSQL via SQLAlchemy session to verify DB connectivity - HTTP GET to ${MLFLOW_TRACKING_URI}/health (default: http://localhost:5000/health) to verify MLflow - Return HTTP 503 if any component is unhealthy, HTTP 200 only when both are healthy - Values: "healthy" / "unhealthy" for database and mlflow fields - Add `import requests as http_requests` and `sa_text` imports Closes task 5.6 in openspec/changes/code-review-fix/tasks.md Co-Authored-By: Claude Sonnet 4.6 --- openspec/changes/code-review-fix/tasks.md | 2 +- services/ml/app/main.py | 63 ++++++++++++++--------- 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/openspec/changes/code-review-fix/tasks.md b/openspec/changes/code-review-fix/tasks.md index ec9fa0b..de1ef52 100644 --- a/openspec/changes/code-review-fix/tasks.md +++ b/openspec/changes/code-review-fix/tasks.md @@ -46,7 +46,7 @@ - [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access - [x] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py` - [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py` -- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409` +- [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409` - [ ] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030` - [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints diff --git a/services/ml/app/main.py b/services/ml/app/main.py index c8bb5c8..152fb95 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -15,7 +15,9 @@ from typing import Optional, Dict, Any, List from datetime import datetime import json -from fastapi import FastAPI, HTTPException, Header, Depends, Security, status +import requests as http_requests + +from fastapi import FastAPI, HTTPException, Header, Depends, Security, status, Response from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import numpy as np @@ -24,7 +26,7 @@ import joblib import mlflow import mlflow.sklearn import mlflow.xgboost -from sqlalchemy import update as sa_update, desc +from sqlalchemy import update as sa_update, desc, text as sa_text from app.config import load_config, PipelineConfig, get_default_config from app.db import get_db, TrainingRun, init_db @@ -183,8 +185,8 @@ class HealthResponse(BaseModel): """Health check response.""" status: str = Field(..., description="healthy, degraded, or unhealthy") model_loaded: bool - mlflow: str = Field(..., description="connected or disconnected") - database: str = Field(..., description="connected or disconnected") + mlflow: str = Field(..., description="healthy or unhealthy") + database: str = Field(..., description="healthy or unhealthy") # --- Model Loading Functions --- @@ -473,43 +475,56 @@ async def startup_event(): # --- Health Check --- @app.get("/health", response_model=HealthResponse) -async def health_check(): +async def health_check(response: Response): """ Health check endpoint. - + Returns service status, model loaded status, and dependency health. + Returns HTTP 503 if any component (database or MLflow) is unhealthy. """ model_loaded = state.model is not None - - # Check MLflow connection - mlflow_status = "disconnected" + + # Check database connection with a real SELECT 1 + db_status = "unhealthy" try: - # TODO: Actually check MLflow connection - mlflow_status = "connected" - except Exception: - pass - - # Check database connection - db_status = "disconnected" + with get_db() as db: + db.execute(sa_text("SELECT 1")) + db_status = "healthy" + except Exception as exc: + logger.warning(f"Health check: database unhealthy: {exc}") + + # Check MLflow connection via HTTP GET to its health endpoint + mlflow_status = "unhealthy" try: - # TODO: Actually check database connection - db_status = "connected" - except Exception: - pass - + mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000") + mlflow_health_url = f"{mlflow_tracking_uri.rstrip('/')}/health" + resp = http_requests.get(mlflow_health_url, timeout=3) + if resp.status_code == 200: + mlflow_status = "healthy" + else: + logger.warning( + f"Health check: MLflow returned HTTP {resp.status_code}" + ) + except Exception as exc: + logger.warning(f"Health check: MLflow unreachable: {exc}") + # Determine overall status - if model_loaded and mlflow_status == "connected" and db_status == "connected": + if db_status == "healthy" and mlflow_status == "healthy": overall_status = "healthy" elif model_loaded: overall_status = "degraded" else: overall_status = "unhealthy" - + + # Return HTTP 503 if any dependency is unhealthy + if db_status != "healthy" or mlflow_status != "healthy": + response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE + return HealthResponse( status=overall_status, model_loaded=model_loaded, mlflow=mlflow_status, - database=db_status + database=db_status, )