feat(ml): add API key authentication via FastAPI Depends() on all endpoints except /health
- Import Header, Depends, Security from fastapi - Add verify_api_key dependency: reads API_KEY env var, checks X-API-Key header, raises HTTP 401 if key mismatch; fail-open if env var not set - Apply Depends(verify_api_key) to all 14 non-health endpoints - /health endpoint remains unauthenticated for liveness probes - Mark task 3.2 as complete in tasks.md Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
577bb2e56e
commit
5f569d9134
2 changed files with 26 additions and 15 deletions
|
|
@ -14,7 +14,7 @@ from typing import Optional, Dict, Any, List
|
|||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi import FastAPI, HTTPException, Header, Depends, Security, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
|
@ -42,6 +42,17 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- API Key Dependency ---
|
||||
|
||||
async def verify_api_key(x_api_key: str = Header(default="")):
|
||||
"""Verify X-API-Key header against API_KEY env var. Fail-open if not configured."""
|
||||
api_key = os.getenv("API_KEY")
|
||||
if not api_key:
|
||||
return # fail-open if not configured
|
||||
if x_api_key != api_key:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
# FastAPI app
|
||||
app = FastAPI(
|
||||
title="Candle Pattern Inference API",
|
||||
|
|
@ -435,7 +446,7 @@ async def health_check():
|
|||
|
||||
# --- Model Info Endpoints (Stubs) ---
|
||||
|
||||
@app.get("/model/info", response_model=ModelInfoResponse)
|
||||
@app.get("/model/info", response_model=ModelInfoResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def get_model_info():
|
||||
"""
|
||||
Get detailed model information and per-class metrics.
|
||||
|
|
@ -464,7 +475,7 @@ async def get_model_info():
|
|||
)
|
||||
|
||||
|
||||
@app.get("/model/labels", response_model=List[LabelInfo])
|
||||
@app.get("/model/labels", response_model=List[LabelInfo], dependencies=[Depends(verify_api_key)])
|
||||
async def get_model_labels():
|
||||
"""
|
||||
Get all pattern labels the model can predict with display colors.
|
||||
|
|
@ -550,7 +561,7 @@ def group_prediction_spans(
|
|||
|
||||
# --- Prediction Endpoints ---
|
||||
|
||||
@app.post("/predict", response_model=PredictResponse)
|
||||
@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def predict(request: PredictRequest):
|
||||
"""
|
||||
Predict candlestick patterns for provided candles.
|
||||
|
|
@ -650,7 +661,7 @@ async def predict(request: PredictRequest):
|
|||
)
|
||||
|
||||
|
||||
@app.post("/predict/batch", response_model=PredictResponse)
|
||||
@app.post("/predict/batch", response_model=PredictResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def predict_batch(request: BatchPredictRequest):
|
||||
"""
|
||||
Batch prediction for a date range.
|
||||
|
|
@ -823,7 +834,7 @@ class DetectPatternsResponse(BaseModel):
|
|||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@app.get("/patterns/available", response_model=List[PatternInfo])
|
||||
@app.get("/patterns/available", response_model=List[PatternInfo], dependencies=[Depends(verify_api_key)])
|
||||
async def patterns_available():
|
||||
"""
|
||||
Return all supported CDL pattern names with display names.
|
||||
|
|
@ -831,7 +842,7 @@ async def patterns_available():
|
|||
return get_available_patterns()
|
||||
|
||||
|
||||
@app.post("/patterns/detect", response_model=DetectPatternsResponse)
|
||||
@app.post("/patterns/detect", response_model=DetectPatternsResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def patterns_detect(request: DetectPatternsRequest):
|
||||
"""
|
||||
Detect TA-Lib CDL patterns on provided candle data.
|
||||
|
|
@ -1116,7 +1127,7 @@ def _run_training_background(run_id: str, model_type: str, config: PipelineConfi
|
|||
logger.info(f"Training thread exiting: run_id={run_id}")
|
||||
|
||||
|
||||
@app.post("/training/start", response_model=TrainingStartResponse)
|
||||
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def training_start(request: TrainingStartRequest):
|
||||
"""
|
||||
Start a training run in a background thread.
|
||||
|
|
@ -1190,7 +1201,7 @@ async def training_start(request: TrainingStartRequest):
|
|||
return TrainingStartResponse(run_id=run_id, status="running")
|
||||
|
||||
|
||||
@app.get("/training/runs", response_model=TrainingRunsResponse)
|
||||
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def training_runs():
|
||||
"""
|
||||
Return training run history from the database, sorted by date descending.
|
||||
|
|
@ -1229,7 +1240,7 @@ class ActiveTrainingResponse(BaseModel):
|
|||
run_id: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/training/active", response_model=ActiveTrainingResponse)
|
||||
@app.get("/training/active", response_model=ActiveTrainingResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def training_active():
|
||||
"""
|
||||
Return whether a training run is currently active and its run_id.
|
||||
|
|
@ -1245,7 +1256,7 @@ class DeleteRunResponse(BaseModel):
|
|||
deleted: bool
|
||||
|
||||
|
||||
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse)
|
||||
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def delete_training_run(run_id: str):
|
||||
"""
|
||||
Delete a training run record and its model artifact.
|
||||
|
|
@ -1312,7 +1323,7 @@ async def delete_training_run(run_id: str):
|
|||
return DeleteRunResponse(run_id=run_id, deleted=True)
|
||||
|
||||
|
||||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
|
||||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def training_dataset_info():
|
||||
"""
|
||||
Return information about the labeled training dataset.
|
||||
|
|
@ -1362,7 +1373,7 @@ class BuildDatasetResponse(BaseModel):
|
|||
labeled_path: str
|
||||
|
||||
|
||||
@app.post("/training/build-dataset", response_model=BuildDatasetResponse)
|
||||
@app.post("/training/build-dataset", response_model=BuildDatasetResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def training_build_dataset():
|
||||
"""
|
||||
Build the labeled training dataset from database annotations.
|
||||
|
|
@ -1405,7 +1416,7 @@ class ModelLoadResponse(BaseModel):
|
|||
_model_swap_lock = threading.Lock()
|
||||
|
||||
|
||||
@app.post("/model/load", response_model=ModelLoadResponse)
|
||||
@app.post("/model/load", response_model=ModelLoadResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def model_load(request: ModelLoadRequest):
|
||||
"""
|
||||
Load a trained model from a completed training run.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue