feat(ml): add API key authentication via FastAPI Depends() on all endpoints except /health

- Import Header, Depends, Security from fastapi
- Add verify_api_key dependency: reads API_KEY env var, checks X-API-Key
  header, raises HTTP 401 if key mismatch; fail-open if env var not set
- Apply Depends(verify_api_key) to all 14 non-health endpoints
- /health endpoint remains unauthenticated for liveness probes
- Mark task 3.2 as complete in tasks.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:04:41 +01:00
parent 577bb2e56e
commit 5f569d9134
2 changed files with 26 additions and 15 deletions

View file

@ -14,7 +14,7 @@ from typing import Optional, Dict, Any, List
from datetime import datetime
import json
from fastapi import FastAPI, HTTPException, status
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import numpy as np
@ -42,6 +42,17 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# --- API Key Dependency ---
async def verify_api_key(x_api_key: str = Header(default="")):
"""Verify X-API-Key header against API_KEY env var. Fail-open if not configured."""
api_key = os.getenv("API_KEY")
if not api_key:
return # fail-open if not configured
if x_api_key != api_key:
raise HTTPException(status_code=401, detail="Unauthorized")
# FastAPI app
app = FastAPI(
title="Candle Pattern Inference API",
@ -435,7 +446,7 @@ async def health_check():
# --- Model Info Endpoints (Stubs) ---
@app.get("/model/info", response_model=ModelInfoResponse)
@app.get("/model/info", response_model=ModelInfoResponse, dependencies=[Depends(verify_api_key)])
async def get_model_info():
"""
Get detailed model information and per-class metrics.
@ -464,7 +475,7 @@ async def get_model_info():
)
@app.get("/model/labels", response_model=List[LabelInfo])
@app.get("/model/labels", response_model=List[LabelInfo], dependencies=[Depends(verify_api_key)])
async def get_model_labels():
"""
Get all pattern labels the model can predict with display colors.
@ -550,7 +561,7 @@ def group_prediction_spans(
# --- Prediction Endpoints ---
@app.post("/predict", response_model=PredictResponse)
@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
async def predict(request: PredictRequest):
"""
Predict candlestick patterns for provided candles.
@ -650,7 +661,7 @@ async def predict(request: PredictRequest):
)
@app.post("/predict/batch", response_model=PredictResponse)
@app.post("/predict/batch", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
async def predict_batch(request: BatchPredictRequest):
"""
Batch prediction for a date range.
@ -823,7 +834,7 @@ class DetectPatternsResponse(BaseModel):
metadata: Dict[str, Any]
@app.get("/patterns/available", response_model=List[PatternInfo])
@app.get("/patterns/available", response_model=List[PatternInfo], dependencies=[Depends(verify_api_key)])
async def patterns_available():
"""
Return all supported CDL pattern names with display names.
@ -831,7 +842,7 @@ async def patterns_available():
return get_available_patterns()
@app.post("/patterns/detect", response_model=DetectPatternsResponse)
@app.post("/patterns/detect", response_model=DetectPatternsResponse, dependencies=[Depends(verify_api_key)])
async def patterns_detect(request: DetectPatternsRequest):
"""
Detect TA-Lib CDL patterns on provided candle data.
@ -1116,7 +1127,7 @@ def _run_training_background(run_id: str, model_type: str, config: PipelineConfi
logger.info(f"Training thread exiting: run_id={run_id}")
@app.post("/training/start", response_model=TrainingStartResponse)
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
async def training_start(request: TrainingStartRequest):
"""
Start a training run in a background thread.
@ -1190,7 +1201,7 @@ async def training_start(request: TrainingStartRequest):
return TrainingStartResponse(run_id=run_id, status="running")
@app.get("/training/runs", response_model=TrainingRunsResponse)
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
async def training_runs():
"""
Return training run history from the database, sorted by date descending.
@ -1229,7 +1240,7 @@ class ActiveTrainingResponse(BaseModel):
run_id: Optional[str] = None
@app.get("/training/active", response_model=ActiveTrainingResponse)
@app.get("/training/active", response_model=ActiveTrainingResponse, dependencies=[Depends(verify_api_key)])
async def training_active():
"""
Return whether a training run is currently active and its run_id.
@ -1245,7 +1256,7 @@ class DeleteRunResponse(BaseModel):
deleted: bool
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse)
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
async def delete_training_run(run_id: str):
"""
Delete a training run record and its model artifact.
@ -1312,7 +1323,7 @@ async def delete_training_run(run_id: str):
return DeleteRunResponse(run_id=run_id, deleted=True)
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
async def training_dataset_info():
"""
Return information about the labeled training dataset.
@ -1362,7 +1373,7 @@ class BuildDatasetResponse(BaseModel):
labeled_path: str
@app.post("/training/build-dataset", response_model=BuildDatasetResponse)
@app.post("/training/build-dataset", response_model=BuildDatasetResponse, dependencies=[Depends(verify_api_key)])
async def training_build_dataset():
"""
Build the labeled training dataset from database annotations.
@ -1405,7 +1416,7 @@ class ModelLoadResponse(BaseModel):
_model_swap_lock = threading.Lock()
@app.post("/model/load", response_model=ModelLoadResponse)
@app.post("/model/load", response_model=ModelLoadResponse, dependencies=[Depends(verify_api_key)])
async def model_load(request: ModelLoadRequest):
"""
Load a trained model from a completed training run.