""" TA-Lib CDL pattern detection module. Extracted from generate_talib_annotations.py for use in the FastAPI service. Provides detect_patterns() for use in pattern detection endpoints. """ from typing import List, Dict, Any, Optional import logging import numpy as np try: import talib except ImportError: talib = None logger = logging.getLogger(__name__) TALIB_PATTERNS: Dict[str, str] = { 'CDLENGULFING': 'Engulfing', 'CDLHAMMER': 'Hammer', 'CDLINVERTEDHAMMER': 'Inverted Hammer', 'CDLSHOOTINGSTAR': 'Shooting Star', 'CDLDOJI': 'Doji', 'CDLDOJISTAR': 'Doji Star', 'CDLMORNINGSTAR': 'Morning Star', 'CDLEVENINGSTAR': 'Evening Star', 'CDLHARAMI': 'Harami', 'CDLHARAMICROSS': 'Harami Cross', 'CDLPIERCING': 'Piercing', 'CDLDARKCLOUDCOVER': 'Dark Cloud Cover', 'CDLMARUBOZU': 'Marubozu', 'CDLSPINNINGTOP': 'Spinning Top', 'CDL3WHITESOLDIERS': 'Three White Soldiers', 'CDL3BLACKCROWS': 'Three Black Crows', 'CDLABANDONEDBABY': 'Abandoned Baby', 'CDLADVANCEBLOCK': 'Advance Block', 'CDLBELTHOLD': 'Belt Hold', 'CDLBREAKAWAY': 'Breakaway', 'CDLCLOSINGMARUBOZU': 'Closing Marubozu', 'CDLCONCEALBABYSWALL': 'Concealing Baby Swallow', 'CDLCOUNTERATTACK': 'Counterattack', 'CDLDRAGONFLYDOJI': 'Dragonfly Doji', 'CDLGAPSIDESIDEWHITE': 'Up/Down Gap Side-by-Side White Lines', 'CDLGRAVESTONEDOJI': 'Gravestone Doji', 'CDLHANGINGMAN': 'Hanging Man', 'CDLHIGHWAVE': 'High Wave', 'CDLHIKKAKE': 'Hikkake', 'CDLHIKKAKEMOD': 'Modified Hikkake', 'CDLHOMINGPIGEON': 'Homing Pigeon', 'CDLIDENTICAL3CROWS': 'Identical Three Crows', 'CDLINNECK': 'In-Neck', 'CDLKICKING': 'Kicking', 'CDLKICKINGBYLENGTH': 'Kicking by Length', 'CDLLADDERBOTTOM': 'Ladder Bottom', 'CDLLONGLEGGEDDOJI': 'Long-Legged Doji', 'CDLLONGLINE': 'Long Line', 'CDLMATCHINGLOW': 'Matching Low', 'CDLMATHOLD': 'Mat Hold', 'CDLMORNINGDOJISTAR': 'Morning Doji Star', 'CDLONNECK': 'On-Neck', 'CDLRISEFALL3METHODS': 'Rising/Falling Three Methods', 'CDLSEPARATINGLINES': 'Separating Lines', 'CDLSHORTLINE': 'Short Line', 'CDLSTALLEDPATTERN': 'Stalled Pattern', 'CDLSTICKSANDWICH': 'Stick Sandwich', 'CDLTAKURI': 'Takuri', 'CDLTASUKIGAP': 'Tasuki Gap', 'CDLTHRUSTING': 'Thrusting', 'CDLTRISTAR': 'Tristar', 'CDLUNIQUE3RIVER': 'Unique Three River', 'CDLUPSIDEGAP2CROWS': 'Upside Gap Two Crows', 'CDLXSIDEGAP3METHODS': 'Upside/Downside Gap Three Methods', } def get_available_patterns() -> List[Dict[str, str]]: """ Return all supported CDL pattern names with display names. Returns: List of dicts with function_name and display_name keys. """ return [ {"function_name": fn, "display_name": display} for fn, display in TALIB_PATTERNS.items() ] def validate_pattern_names(pattern_names: List[str]) -> List[str]: """ Validate a list of pattern names against the supported set. Args: pattern_names: List of CDL function names to validate. Returns: List of invalid pattern names (empty if all valid). """ return [name for name in pattern_names if name not in TALIB_PATTERNS] def detect_patterns( candles: List[Dict[str, Any]], pattern_names: Optional[List[str]] = None, ) -> List[Dict[str, Any]]: """ Run TA-Lib CDL pattern functions on candle data and return span annotations. Args: candles: List of candle dicts with keys time, open, high, low, close. pattern_names: CDL function names to run. None or empty list means run all. Returns: List of span annotation dicts with keys: start_time, end_time, label, confidence, source, notes. """ if talib is None: raise RuntimeError("TA-Lib is not installed") if not candles: return [] if not pattern_names: pattern_names = list(TALIB_PATTERNS.keys()) times = np.array([c["time"] for c in candles], dtype=np.float64) open_prices = np.array([c["open"] for c in candles], dtype=np.float64) high_prices = np.array([c["high"] for c in candles], dtype=np.float64) low_prices = np.array([c["low"] for c in candles], dtype=np.float64) close_prices = np.array([c["close"] for c in candles], dtype=np.float64) n = len(candles) annotations = [] for pattern_func in pattern_names: if not hasattr(talib, pattern_func): logger.warning("Unknown TA-Lib function: %s", pattern_func) continue try: func = getattr(talib, pattern_func) values = func(open_prices, high_prices, low_prices, close_prices) except Exception as exc: logger.error("Error running %s: %s", pattern_func, exc) continue friendly_name = TALIB_PATTERNS.get(pattern_func, pattern_func) for idx, value in enumerate(values): if value == 0: continue if value > 0: label = f"Bullish {friendly_name}" else: label = f"Bearish {friendly_name}" start_idx = max(0, idx - 1) end_idx = min(n - 1, idx + 1) annotations.append({ "start_time": int(times[start_idx]), "end_time": int(times[end_idx]), "label": label, "confidence": abs(int(value)) / 100.0, "source": "talib", "notes": f"TA-Lib {pattern_func} detection", }) logger.info("Detected %d pattern annotations", len(annotations)) return annotations