feat: add FastAPI pattern detection endpoints (Section 1)

- Extract CDL pattern detection logic into services/ml/app/patterns.py
  with TALIB_PATTERNS dict, get_available_patterns(), validate_pattern_names(),
  and detect_patterns(candles, pattern_names) functions
- Add GET /patterns/available endpoint returning all 54 supported CDL pattern
  names with display names
- Add POST /patterns/detect endpoint accepting {candles, patterns}, running
  selected CDL functions, returning span annotations with source "talib"
- Add input validation: reject invalid pattern names with HTTP 400,
  treat empty patterns list as "run all"

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-17 18:34:14 +01:00
parent 38df874255
commit b8e649e333
3 changed files with 320 additions and 0 deletions

View file

@ -22,6 +22,12 @@ import mlflow.xgboost
from app.config import load_config, PipelineConfig
from app.preprocessing import preprocess_candles, extract_feature_columns
from app.patterns import (
TALIB_PATTERNS,
get_available_patterns,
validate_pattern_names,
detect_patterns,
)
# Configure logging
logging.basicConfig(
@ -741,6 +747,85 @@ async def predict_batch(request: BatchPredictRequest):
)
# ---------------------------------------------------------------------------
# Pattern Detection Endpoints
# ---------------------------------------------------------------------------
class PatternInfo(BaseModel):
"""A single supported CDL pattern."""
function_name: str
display_name: str
class DetectPatternsRequest(BaseModel):
"""Request model for POST /patterns/detect."""
candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
patterns: List[str] = Field(
default=[],
description="CDL function names to run. Empty list means run all.",
)
class SpanAnnotation(BaseModel):
"""A span annotation returned by pattern detection."""
start_time: int
end_time: int
label: str
confidence: float = Field(..., ge=0.0, le=1.0)
source: str
notes: str
class DetectPatternsResponse(BaseModel):
"""Response model for POST /patterns/detect."""
annotations: List[SpanAnnotation]
metadata: Dict[str, Any]
@app.get("/patterns/available", response_model=List[PatternInfo])
async def patterns_available():
"""
Return all supported CDL pattern names with display names.
"""
return get_available_patterns()
@app.post("/patterns/detect", response_model=DetectPatternsResponse)
async def patterns_detect(request: DetectPatternsRequest):
"""
Detect TA-Lib CDL patterns on provided candle data.
- Empty ``patterns`` list runs all available CDL functions.
- Invalid pattern names return HTTP 400.
"""
# 1.4 Validate pattern names
if request.patterns:
invalid = validate_pattern_names(request.patterns)
if invalid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
f"Use GET /patterns/available to see supported patterns.",
)
candles_data = [c.model_dump() for c in request.candles]
try:
raw_annotations = detect_patterns(candles_data, request.patterns or None)
except RuntimeError as exc:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=str(exc),
)
annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
return DetectPatternsResponse(
annotations=annotations,
metadata={"source": "talib", "count": len(annotations)},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)

173
services/ml/app/patterns.py Normal file
View file

@ -0,0 +1,173 @@
"""
TA-Lib CDL pattern detection module.
Extracted from generate_talib_annotations.py for use in the FastAPI service.
Provides detect_patterns() for use in pattern detection endpoints.
"""
from typing import List, Dict, Any, Optional
import logging
import numpy as np
try:
import talib
except ImportError:
talib = None
logger = logging.getLogger(__name__)
TALIB_PATTERNS: Dict[str, str] = {
'CDLENGULFING': 'Engulfing',
'CDLHAMMER': 'Hammer',
'CDLINVERTEDHAMMER': 'Inverted Hammer',
'CDLSHOOTINGSTAR': 'Shooting Star',
'CDLDOJI': 'Doji',
'CDLDOJISTAR': 'Doji Star',
'CDLMORNINGSTAR': 'Morning Star',
'CDLEVENINGSTAR': 'Evening Star',
'CDLHARAMI': 'Harami',
'CDLHARAMICROSS': 'Harami Cross',
'CDLPIERCING': 'Piercing',
'CDLDARKCLOUDCOVER': 'Dark Cloud Cover',
'CDLMARUBOZU': 'Marubozu',
'CDLSPINNINGTOP': 'Spinning Top',
'CDL3WHITESOLDIERS': 'Three White Soldiers',
'CDL3BLACKCROWS': 'Three Black Crows',
'CDLABANDONEDBABY': 'Abandoned Baby',
'CDLADVANCEBLOCK': 'Advance Block',
'CDLBELTHOLD': 'Belt Hold',
'CDLBREAKAWAY': 'Breakaway',
'CDLCLOSINGMARUBOZU': 'Closing Marubozu',
'CDLCONCEALBABYSWALL': 'Concealing Baby Swallow',
'CDLCOUNTERATTACK': 'Counterattack',
'CDLDRAGONFLYDOJI': 'Dragonfly Doji',
'CDLGAPSIDESIDEWHITE': 'Up/Down Gap Side-by-Side White Lines',
'CDLGRAVESTONEDOJI': 'Gravestone Doji',
'CDLHANGINGMAN': 'Hanging Man',
'CDLHIGHWAVE': 'High Wave',
'CDLHIKKAKE': 'Hikkake',
'CDLHIKKAKEMOD': 'Modified Hikkake',
'CDLHOMINGPIGEON': 'Homing Pigeon',
'CDLIDENTICAL3CROWS': 'Identical Three Crows',
'CDLINNECK': 'In-Neck',
'CDLKICKING': 'Kicking',
'CDLKICKINGBYLENGTH': 'Kicking by Length',
'CDLLADDERBOTTOM': 'Ladder Bottom',
'CDLLONGLEGGEDDOJI': 'Long-Legged Doji',
'CDLLONGLINE': 'Long Line',
'CDLMATCHINGLOW': 'Matching Low',
'CDLMATHOLD': 'Mat Hold',
'CDLMORNINGDOJISTAR': 'Morning Doji Star',
'CDLONNECK': 'On-Neck',
'CDLRISEFALL3METHODS': 'Rising/Falling Three Methods',
'CDLSEPARATINGLINES': 'Separating Lines',
'CDLSHORTLINE': 'Short Line',
'CDLSTALLEDPATTERN': 'Stalled Pattern',
'CDLSTICKSANDWICH': 'Stick Sandwich',
'CDLTAKURI': 'Takuri',
'CDLTASUKIGAP': 'Tasuki Gap',
'CDLTHRUSTING': 'Thrusting',
'CDLTRISTAR': 'Tristar',
'CDLUNIQUE3RIVER': 'Unique Three River',
'CDLUPSIDEGAP2CROWS': 'Upside Gap Two Crows',
'CDLXSIDEGAP3METHODS': 'Upside/Downside Gap Three Methods',
}
def get_available_patterns() -> List[Dict[str, str]]:
"""
Return all supported CDL pattern names with display names.
Returns:
List of dicts with function_name and display_name keys.
"""
return [
{"function_name": fn, "display_name": display}
for fn, display in TALIB_PATTERNS.items()
]
def validate_pattern_names(pattern_names: List[str]) -> List[str]:
"""
Validate a list of pattern names against the supported set.
Args:
pattern_names: List of CDL function names to validate.
Returns:
List of invalid pattern names (empty if all valid).
"""
return [name for name in pattern_names if name not in TALIB_PATTERNS]
def detect_patterns(
candles: List[Dict[str, Any]],
pattern_names: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Run TA-Lib CDL pattern functions on candle data and return span annotations.
Args:
candles: List of candle dicts with keys time, open, high, low, close.
pattern_names: CDL function names to run. None or empty list means run all.
Returns:
List of span annotation dicts with keys:
start_time, end_time, label, confidence, source, notes.
"""
if talib is None:
raise RuntimeError("TA-Lib is not installed")
if not candles:
return []
if not pattern_names:
pattern_names = list(TALIB_PATTERNS.keys())
times = np.array([c["time"] for c in candles], dtype=np.float64)
open_prices = np.array([c["open"] for c in candles], dtype=np.float64)
high_prices = np.array([c["high"] for c in candles], dtype=np.float64)
low_prices = np.array([c["low"] for c in candles], dtype=np.float64)
close_prices = np.array([c["close"] for c in candles], dtype=np.float64)
n = len(candles)
annotations = []
for pattern_func in pattern_names:
if not hasattr(talib, pattern_func):
logger.warning("Unknown TA-Lib function: %s", pattern_func)
continue
try:
func = getattr(talib, pattern_func)
values = func(open_prices, high_prices, low_prices, close_prices)
except Exception as exc:
logger.error("Error running %s: %s", pattern_func, exc)
continue
friendly_name = TALIB_PATTERNS.get(pattern_func, pattern_func)
for idx, value in enumerate(values):
if value == 0:
continue
if value > 0:
label = f"Bullish {friendly_name}"
else:
label = f"Bearish {friendly_name}"
start_idx = max(0, idx - 1)
end_idx = min(n - 1, idx + 1)
annotations.append({
"start_time": int(times[start_idx]),
"end_time": int(times[end_idx]),
"label": label,
"confidence": abs(int(value)) / 100.0,
"source": "talib",
"notes": f"TA-Lib {pattern_func} detection",
})
logger.info("Detected %d pattern annotations", len(annotations))
return annotations