From 5f569d9134c78b4ef5a48e72b7d895d753e1ec13 Mon Sep 17 00:00:00 2001 From: Marko Djordjevic Date: Wed, 18 Feb 2026 11:04:41 +0100 Subject: [PATCH] feat(ml): add API key authentication via FastAPI Depends() on all endpoints except /health - Import Header, Depends, Security from fastapi - Add verify_api_key dependency: reads API_KEY env var, checks X-API-Key header, raises HTTP 401 if key mismatch; fail-open if env var not set - Apply Depends(verify_api_key) to all 14 non-health endpoints - /health endpoint remains unauthenticated for liveness probes - Mark task 3.2 as complete in tasks.md Co-Authored-By: Claude Sonnet 4.6 --- openspec/changes/code-review-fix/tasks.md | 2 +- services/ml/app/main.py | 39 +++++++++++++++-------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/openspec/changes/code-review-fix/tasks.md b/openspec/changes/code-review-fix/tasks.md index cbe64bf..0be4f62 100644 --- a/openspec/changes/code-review-fix/tasks.md +++ b/openspec/changes/code-review-fix/tasks.md @@ -20,7 +20,7 @@ ## 3. Authentication - [x] 3.1 `[sonnet]` Create `src/middleware.ts` with API key auth middleware: read `API_KEY` env var, check `X-API-Key` header on all `/api/*` routes except `/api/health`, return 401 if invalid -- [ ] 3.2 `[sonnet]` Add FastAPI `Depends()` API key dependency in `services/ml/app/main.py`: read `API_KEY` env var, check `X-API-Key` header, exempt `/health` endpoint +- [x] 3.2 `[sonnet]` Add FastAPI `Depends()` API key dependency in `services/ml/app/main.py`: read `API_KEY` env var, check `X-API-Key` header, exempt `/health` endpoint - [ ] 3.3 `[sonnet]` Update all Next.js proxy routes to forward `X-API-Key` header to ML service - [ ] 3.4 `[haiku]` Add `API_KEY` to `.env.example` with placeholder value and instructions diff --git a/services/ml/app/main.py b/services/ml/app/main.py index c3ffea8..700cc98 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -14,7 +14,7 @@ from typing import Optional, Dict, Any, List from datetime import datetime import json -from fastapi import FastAPI, HTTPException, status +from fastapi import FastAPI, HTTPException, Header, Depends, Security, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import numpy as np @@ -42,6 +42,17 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) +# --- API Key Dependency --- + +async def verify_api_key(x_api_key: str = Header(default="")): + """Verify X-API-Key header against API_KEY env var. Fail-open if not configured.""" + api_key = os.getenv("API_KEY") + if not api_key: + return # fail-open if not configured + if x_api_key != api_key: + raise HTTPException(status_code=401, detail="Unauthorized") + + # FastAPI app app = FastAPI( title="Candle Pattern Inference API", @@ -435,7 +446,7 @@ async def health_check(): # --- Model Info Endpoints (Stubs) --- -@app.get("/model/info", response_model=ModelInfoResponse) +@app.get("/model/info", response_model=ModelInfoResponse, dependencies=[Depends(verify_api_key)]) async def get_model_info(): """ Get detailed model information and per-class metrics. @@ -464,7 +475,7 @@ async def get_model_info(): ) -@app.get("/model/labels", response_model=List[LabelInfo]) +@app.get("/model/labels", response_model=List[LabelInfo], dependencies=[Depends(verify_api_key)]) async def get_model_labels(): """ Get all pattern labels the model can predict with display colors. @@ -550,7 +561,7 @@ def group_prediction_spans( # --- Prediction Endpoints --- -@app.post("/predict", response_model=PredictResponse) +@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(verify_api_key)]) async def predict(request: PredictRequest): """ Predict candlestick patterns for provided candles. @@ -650,7 +661,7 @@ async def predict(request: PredictRequest): ) -@app.post("/predict/batch", response_model=PredictResponse) +@app.post("/predict/batch", response_model=PredictResponse, dependencies=[Depends(verify_api_key)]) async def predict_batch(request: BatchPredictRequest): """ Batch prediction for a date range. @@ -823,7 +834,7 @@ class DetectPatternsResponse(BaseModel): metadata: Dict[str, Any] -@app.get("/patterns/available", response_model=List[PatternInfo]) +@app.get("/patterns/available", response_model=List[PatternInfo], dependencies=[Depends(verify_api_key)]) async def patterns_available(): """ Return all supported CDL pattern names with display names. @@ -831,7 +842,7 @@ async def patterns_available(): return get_available_patterns() -@app.post("/patterns/detect", response_model=DetectPatternsResponse) +@app.post("/patterns/detect", response_model=DetectPatternsResponse, dependencies=[Depends(verify_api_key)]) async def patterns_detect(request: DetectPatternsRequest): """ Detect TA-Lib CDL patterns on provided candle data. @@ -1116,7 +1127,7 @@ def _run_training_background(run_id: str, model_type: str, config: PipelineConfi logger.info(f"Training thread exiting: run_id={run_id}") -@app.post("/training/start", response_model=TrainingStartResponse) +@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)]) async def training_start(request: TrainingStartRequest): """ Start a training run in a background thread. @@ -1190,7 +1201,7 @@ async def training_start(request: TrainingStartRequest): return TrainingStartResponse(run_id=run_id, status="running") -@app.get("/training/runs", response_model=TrainingRunsResponse) +@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)]) async def training_runs(): """ Return training run history from the database, sorted by date descending. @@ -1229,7 +1240,7 @@ class ActiveTrainingResponse(BaseModel): run_id: Optional[str] = None -@app.get("/training/active", response_model=ActiveTrainingResponse) +@app.get("/training/active", response_model=ActiveTrainingResponse, dependencies=[Depends(verify_api_key)]) async def training_active(): """ Return whether a training run is currently active and its run_id. @@ -1245,7 +1256,7 @@ class DeleteRunResponse(BaseModel): deleted: bool -@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse) +@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)]) async def delete_training_run(run_id: str): """ Delete a training run record and its model artifact. @@ -1312,7 +1323,7 @@ async def delete_training_run(run_id: str): return DeleteRunResponse(run_id=run_id, deleted=True) -@app.get("/training/dataset-info", response_model=DatasetInfoResponse) +@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)]) async def training_dataset_info(): """ Return information about the labeled training dataset. @@ -1362,7 +1373,7 @@ class BuildDatasetResponse(BaseModel): labeled_path: str -@app.post("/training/build-dataset", response_model=BuildDatasetResponse) +@app.post("/training/build-dataset", response_model=BuildDatasetResponse, dependencies=[Depends(verify_api_key)]) async def training_build_dataset(): """ Build the labeled training dataset from database annotations. @@ -1405,7 +1416,7 @@ class ModelLoadResponse(BaseModel): _model_swap_lock = threading.Lock() -@app.post("/model/load", response_model=ModelLoadResponse) +@app.post("/model/load", response_model=ModelLoadResponse, dependencies=[Depends(verify_api_key)]) async def model_load(request: ModelLoadRequest): """ Load a trained model from a completed training run.