feat: add FastAPI pattern detection endpoints (Section 1)

- Extract CDL pattern detection logic into services/ml/app/patterns.py
  with TALIB_PATTERNS dict, get_available_patterns(), validate_pattern_names(),
  and detect_patterns(candles, pattern_names) functions
- Add GET /patterns/available endpoint returning all 54 supported CDL pattern
  names with display names
- Add POST /patterns/detect endpoint accepting {candles, patterns}, running
  selected CDL functions, returning span annotations with source "talib"
- Add input validation: reject invalid pattern names with HTTP 400,
  treat empty patterns list as "run all"

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-17 18:34:14 +01:00
parent 38df874255
commit b8e649e333
3 changed files with 320 additions and 0 deletions

View file

@ -22,6 +22,12 @@ import mlflow.xgboost
from app.config import load_config, PipelineConfig
from app.preprocessing import preprocess_candles, extract_feature_columns
from app.patterns import (
TALIB_PATTERNS,
get_available_patterns,
validate_pattern_names,
detect_patterns,
)
# Configure logging
logging.basicConfig(
@ -741,6 +747,85 @@ async def predict_batch(request: BatchPredictRequest):
)
# ---------------------------------------------------------------------------
# Pattern Detection Endpoints
# ---------------------------------------------------------------------------
class PatternInfo(BaseModel):
"""A single supported CDL pattern."""
function_name: str
display_name: str
class DetectPatternsRequest(BaseModel):
"""Request model for POST /patterns/detect."""
candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
patterns: List[str] = Field(
default=[],
description="CDL function names to run. Empty list means run all.",
)
class SpanAnnotation(BaseModel):
"""A span annotation returned by pattern detection."""
start_time: int
end_time: int
label: str
confidence: float = Field(..., ge=0.0, le=1.0)
source: str
notes: str
class DetectPatternsResponse(BaseModel):
"""Response model for POST /patterns/detect."""
annotations: List[SpanAnnotation]
metadata: Dict[str, Any]
@app.get("/patterns/available", response_model=List[PatternInfo])
async def patterns_available():
"""
Return all supported CDL pattern names with display names.
"""
return get_available_patterns()
@app.post("/patterns/detect", response_model=DetectPatternsResponse)
async def patterns_detect(request: DetectPatternsRequest):
"""
Detect TA-Lib CDL patterns on provided candle data.
- Empty ``patterns`` list runs all available CDL functions.
- Invalid pattern names return HTTP 400.
"""
# 1.4 Validate pattern names
if request.patterns:
invalid = validate_pattern_names(request.patterns)
if invalid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
f"Use GET /patterns/available to see supported patterns.",
)
candles_data = [c.model_dump() for c in request.candles]
try:
raw_annotations = detect_patterns(candles_data, request.patterns or None)
except RuntimeError as exc:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=str(exc),
)
annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
return DetectPatternsResponse(
annotations=annotations,
metadata={"source": "talib", "count": len(annotations)},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)