candle-annotator/services/ml/generate_talib_annotations.py
Marko Djordjevic 847ff67986 feat(ml): add TA-Lib annotation generation and import workflow
Add complete workflow for using TA-Lib to bootstrap training data:

- generate_talib_annotations.py: Python script to run TA-Lib CDL* functions
  and output span annotations in UI-compatible format
- import_talib_annotations.ts: TypeScript script to import generated
  annotations into the UI database with auto-label-type creation
- npm script 'import-annotations' for easy execution
- TALIB_WORKFLOW.md: Comprehensive guide covering the full cycle:
  * Generate patterns with TA-Lib
  * Import into UI
  * Review and edit in browser
  * Export and train model
  * Compare predictions with TA-Lib detections
  * Iterate for improvement

This enables the intended workflow: use TA-Lib for initial annotations,
manually refine them, then train a model that learns from corrections.
2026-02-15 19:18:28 +01:00

267 lines
8.9 KiB
Python

#!/usr/bin/env python3
"""
Generate span annotations from TA-Lib candlestick pattern functions.
This script runs TA-Lib CDL* functions on OHLCV data and outputs
annotations in a format that can be imported into the Candle Annotator UI.
Usage:
python generate_talib_annotations.py --input data/raw/OHLCV.csv --output talib_annotations.json
"""
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional

import pandas as pd
import numpy as np
# TA-Lib is mandatory for every detection path in this script, so fail fast
# with installation hints when it is missing.
try:
    import talib
except ImportError:
    # Diagnostics go to stderr. sys.exit is used instead of the `exit` helper
    # injected by the `site` module, which is intended for the interactive
    # interpreter and is absent when Python runs with -S.
    print("ERROR: TA-Lib not installed. Install with: pip install TA-Lib", file=sys.stderr)
    print("Note: You may need to install the C library first. See DEPLOYMENT.md", file=sys.stderr)
    sys.exit(1)

logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
# TA-Lib candlestick pattern functions with friendly names
TALIB_PATTERNS = {
'CDLENGULFING': 'Engulfing',
'CDLHAMMER': 'Hammer',
'CDLINVERTEDHAMMER': 'Inverted Hammer',
'CDLSHOOTINGSTAR': 'Shooting Star',
'CDLDOJI': 'Doji',
'CDLDOJISTAR': 'Doji Star',
'CDLMORNINGSTAR': 'Morning Star',
'CDLEVENINGSTAR': 'Evening Star',
'CDLHARAMI': 'Harami',
'CDLHARAMICROSS': 'Harami Cross',
'CDLPIERCING': 'Piercing',
'CDLDARKCLOUDCOVER': 'Dark Cloud Cover',
'CDLTHREEWHITESOLDIERS': 'Three White Soldiers',
'CDLTHREEBLACKCROWS': 'Three Black Crows',
'CDLMARUBOZU': 'Marubozu',
'CDLSPINNINGTOP': 'Spinning Top',
'CDL3BLACKCROWS': 'Three Black Crows',
'CDL3WHITESOLDIERS': 'Three White Soldiers',
'CDLABANDONEDBABY': 'Abandoned Baby',
'CDLADVANCEBLOCK': 'Advance Block',
'CDLBELTHOLD': 'Belt Hold',
'CDLBREAKAWAY': 'Breakaway',
'CDLCLOSINGMARUBOZU': 'Closing Marubozu',
'CDLCONCEALBABYSWALL': 'Concealing Baby Swallow',
'CDLCOUNTERATTACK': 'Counterattack',
'CDLDRAGONFLYDOJI': 'Dragonfly Doji',
'CDLGAPSIDESIDEWHITE': 'Up/Down Gap Side-by-Side White Lines',
'CDLGRAVESTONEDOJI': 'Gravestone Doji',
'CDLHANGINGMAN': 'Hanging Man',
'CDLHIGHWAVE': 'High Wave',
'CDLHIKKAKE': 'Hikkake',
'CDLHIKKAKEMOD': 'Modified Hikkake',
'CDLHOMINGPIGEON': 'Homing Pigeon',
'CDLIDENTICAL3CROWS': 'Identical Three Crows',
'CDLINNECK': 'In-Neck',
'CDLKICKING': 'Kicking',
'CDLKICKINGBYLENGTH': 'Kicking by Length',
'CDLLADDERBOTTOM': 'Ladder Bottom',
'CDLLONGLEGGEDDOJI': 'Long-Legged Doji',
'CDLLONGLINE': 'Long Line',
'CDLMATCHINGLOW': 'Matching Low',
'CDLMATHOLD': 'Mat Hold',
'CDLMORNINGDOJISTAR': 'Morning Doji Star',
'CDLONNECK': 'On-Neck',
'CDLRISEFALL3METHODS': 'Rising/Falling Three Methods',
'CDLSEPARATINGLINES': 'Separating Lines',
'CDLSHORTLINE': 'Short Line',
'CDLSTALLEDPATTERN': 'Stalled Pattern',
'CDLSTICKSANDWICH': 'Stick Sandwich',
'CDLTAKURI': 'Takuri',
'CDLTASUKIGAP': 'Tasuki Gap',
'CDLTHRUSTING': 'Thrusting',
'CDLTRISTAR': 'Tristar',
'CDLUNIQUE3RIVER': 'Unique Three River',
'CDLUPSIDEGAP2CROWS': 'Upside Gap Two Crows',
'CDLXSIDEGAP3METHODS': 'Upside/Downside Gap Three Methods',
}
def load_ohlcv(input_path: str) -> pd.DataFrame:
    """
    Read an OHLCV CSV file into a DataFrame and verify its columns.

    The file must contain ``time``, ``open``, ``high``, ``low`` and ``close``
    columns; ``volume`` and any other columns are optional and passed through.

    Raises:
        ValueError: naming every required column that the file lacks.
    """
    logger.info(f"Loading OHLCV data from {input_path}")
    frame = pd.read_csv(input_path)

    absent = [
        name
        for name in ('time', 'open', 'high', 'low', 'close')
        if name not in frame.columns
    ]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    logger.info(f"Loaded {len(frame)} candles")
    return frame
def detect_talib_patterns(df: pd.DataFrame, patterns: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Run TA-Lib CDL* functions on OHLCV data.

    Args:
        df: DataFrame with time, open, high, low, close columns.
        patterns: TA-Lib function names to run (default: every key of
            TALIB_PATTERNS). Names the talib module does not provide are
            logged as warnings and skipped.

    Returns:
        A copy of the time/open/high/low/close columns plus one column per
        pattern that was run, holding TA-Lib's integer output (0 = no pattern).
        If a TA-Lib call raises, its column is filled with zeros and the error
        is logged.

    Raises:
        ValueError: if a price column cannot be converted to float64.
    """
    if patterns is None:
        patterns = list(TALIB_PATTERNS.keys())

    results = df[['time', 'open', 'high', 'low', 'close']].copy()

    # TA-Lib's C routines only accept contiguous float64 ("double") arrays and
    # raise otherwise. CSV prices that pandas parses as int64 (whole-number
    # prices) would make every CDL* call fail and be replaced by zeros below,
    # silently producing no annotations. So convert once, up front.
    open_prices, high_prices, low_prices, close_prices = (
        np.ascontiguousarray(df[col].to_numpy(dtype=np.float64))
        for col in ('open', 'high', 'low', 'close')
    )

    logger.info(f"Running {len(patterns)} TA-Lib pattern detection functions...")
    for pattern_func in patterns:
        if not hasattr(talib, pattern_func):
            logger.warning(f"Unknown TA-Lib function: {pattern_func}")
            continue
        try:
            func = getattr(talib, pattern_func)
            pattern_values = func(open_prices, high_prices, low_prices, close_prices)
            results[pattern_func] = pattern_values
        except Exception as e:
            # Best effort: one failing detector must not abort the whole run.
            logger.error(f"Error running {pattern_func}: {e}")
            results[pattern_func] = 0

    return results
def create_span_annotations(
    df: pd.DataFrame,
    min_confidence: int = 100,
    span_before: int = 1,
    span_after: int = 1,
    pattern_names: Optional[Dict[str, str]] = None,
) -> List[Dict[str, Any]]:
    """
    Convert TA-Lib pattern detection results to span annotations.

    TA-Lib CDL* functions return, per candle:
      - 0: no pattern
      - positive (normally 100): bullish pattern
      - negative (normally -100): bearish pattern
    Some functions (e.g. CDLHIKKAKE / CDLHIKKAKEMOD) can also emit +/-200 for
    a confirmed pattern. Confidence is therefore clamped to the 0-1 range, and
    such raw values are recorded in the annotation notes.

    Args:
        df: DataFrame with a ``time`` column plus one column per CDL* function,
            as produced by ``detect_talib_patterns``. Rows are addressed by
            position, so any index (not only a RangeIndex) works.
        min_confidence: Minimum absolute TA-Lib value to keep
            (100 = only full matches).
        span_before: Number of candles before the signal candle to include in
            the span (default 1).
        span_after: Number of candles after the signal candle to include in
            the span (default 1). TA-Lib flags a pattern on the candle where it
            completes, so callers may prefer 0 here. The defaults keep the
            original 3-candle span centred on the signal candle.
        pattern_names: Mapping of CDL* column name to friendly name. Defaults to
            TALIB_PATTERNS; unmapped columns fall back to the column name.

    Returns:
        List of annotation dicts, ordered by candle and then by pattern column.

    Raises:
        ValueError: if span_before or span_after is negative.
    """
    if span_before < 0 or span_after < 0:
        raise ValueError("span_before and span_after must be non-negative")
    names = TALIB_PATTERNS if pattern_names is None else pattern_names

    pattern_cols = [col for col in df.columns if col.startswith('CDL')]
    if not pattern_cols or df.empty:
        return []

    # Work on plain arrays instead of df.iterrows(). iterrows builds a Series per
    # row and yields index *labels*, which are not valid iloc positions once the
    # frame has a non-RangeIndex.
    values = df[pattern_cols].to_numpy()
    times = df['time'].to_numpy()
    last_pos = len(df) - 1

    hits = (values != 0) & (np.abs(values) >= min_confidence)

    annotations = []
    # np.nonzero on a 2-D array yields hits in row-major order: by candle first,
    # then by pattern column (the same order as a nested row/column loop).
    for pos, col_idx in zip(*np.nonzero(hits)):
        pattern_col = pattern_cols[col_idx]
        raw_value = int(values[pos, col_idx])
        friendly_name = names.get(pattern_col, pattern_col)
        direction = 'Bullish' if raw_value > 0 else 'Bearish'

        start_pos = max(0, pos - span_before)
        end_pos = min(last_pos, pos + span_after)

        notes = f'TA-Lib {pattern_col} detection'
        if abs(raw_value) > 100:
            # Preserve information lost by clamping confidence to 1.0.
            notes += f' (raw value {raw_value})'

        annotations.append({
            'start_time': int(times[start_pos]),
            'end_time': int(times[end_pos]),
            'label': f"{direction} {friendly_name}",
            'confidence': min(abs(raw_value), 100) / 100.0,  # normalized to 0-1
            'source': 'programmatic',
            'notes': notes,
        })

    return annotations
def save_annotations(annotations: List[Dict[str, Any]], output_path: str):
    """
    Write annotations to a JSON file in the format compatible with the UI.

    The payload has the shape
    ``{"annotations": [...], "metadata": {"source": "talib", "count": N}}``.
    Missing parent directories of ``output_path`` are created first.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        'annotations': annotations,
        'metadata': {'source': 'talib', 'count': len(annotations)},
    }
    with destination.open('w') as handle:
        json.dump(payload, handle, indent=2)

    logger.info(f"Saved {len(annotations)} annotations to {output_path}")
def main():
    """
    Command-line entry point.

    Loads an OHLCV CSV, runs the selected TA-Lib CDL* detectors, converts the
    hits into span annotations, logs the label distribution and writes the
    JSON output. If the input cannot be loaded, it logs the reason and exits
    with status 1 instead of dumping a traceback.
    """
    from collections import Counter

    parser = argparse.ArgumentParser(description='Generate span annotations from TA-Lib patterns')
    parser.add_argument('--input', '-i', required=True, help='Input OHLCV CSV file')
    parser.add_argument('--output', '-o', required=True, help='Output JSON file')
    parser.add_argument('--min-confidence', type=int, default=100,
                        help='Minimum confidence (0-100, default: 100 = perfect matches only)')
    parser.add_argument('--patterns', nargs='+',
                        help='Specific patterns to detect (default: all)')
    args = parser.parse_args()

    # Load data. OSError covers a missing or unreadable file. ValueError covers
    # missing columns (raised by load_ohlcv) and pandas CSV parse errors, which
    # subclass ValueError.
    try:
        df = load_ohlcv(args.input)
    except (OSError, ValueError) as e:
        logger.error(f"Could not load OHLCV data from {args.input}: {e}")
        raise SystemExit(1) from e

    # Detect patterns
    patterns = args.patterns if args.patterns else list(TALIB_PATTERNS.keys())
    results_df = detect_talib_patterns(df, patterns)

    # Create annotations
    annotations = create_span_annotations(results_df, min_confidence=args.min_confidence)
    logger.info(f"Found {len(annotations)} pattern annotations")

    # Show label distribution
    label_counts = Counter(ann['label'] for ann in annotations)
    logger.info("Pattern distribution:")
    for label, count in label_counts.most_common():
        logger.info(f" {label}: {count}")

    # Save
    save_annotations(annotations, args.output)

    logger.info("\nNext steps:")
    # The UI import script ships as the `import-annotations` npm script.
    logger.info("1. Review/edit annotations: import into the UI with `npm run import-annotations` "
                "(see TALIB_WORKFLOW.md)")
    logger.info("2. Or use directly for training: Copy to services/ml/data/annotations/export.json")


if __name__ == '__main__':
    main()