candle-annotator/services/ml/generate_talib_annotations.py

222 lines
7.1 KiB
Python

#!/usr/bin/env python3
"""
Generate span annotations from TA-Lib candlestick pattern functions.

This script runs TA-Lib CDL* functions on OHLCV data and outputs
annotations in a format that can be imported into the Candle Annotator UI.

Usage:
    python generate_talib_annotations.py --input data/raw/OHLCV.csv --output talib_annotations.json
"""
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import List, Dict, Any
import pandas as pd
import numpy as np
try:
    import talib
except ImportError:
    # Fail fast with install guidance. Per the message below, the Python
    # package may also need the native TA-Lib C library installed first.
    # Errors go to stderr, and sys.exit is used because the bare exit()
    # builtin comes from the `site` module and is absent under `python -S`
    # or in frozen executables.
    print("ERROR: TA-Lib not installed. Install with: pip install TA-Lib", file=sys.stderr)
    print("Note: You may need to install the C library first. See DEPLOYMENT.md", file=sys.stderr)
    sys.exit(1)
# Add app directory to path to import from app.patterns.
# Order matters: the sys.path insert must run before the import below.
sys.path.insert(0, str(Path(__file__).parent / 'app'))
# TALIB_PATTERNS is used in this file as a mapping from TA-Lib CDL* function
# name (its keys) to a human-readable pattern name (its values).
# NOTE(review): presumably a plain dict; verify in app/patterns.py.
from patterns import TALIB_PATTERNS
# Configure root logging at import time so every message is printed as
# "[LEVEL] message".
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
def load_ohlcv(input_path: str) -> pd.DataFrame:
    """
    Load OHLCV data from a CSV file.

    Expected columns: time, open, high, low, close[, volume].
    ``time`` may be a numeric Unix timestamp (seconds) or a date string.
    Date strings are parsed and converted to Unix seconds; values without
    an explicit offset are treated as UTC.

    The result is sorted by ``time`` and re-indexed 0..n-1. TA-Lib evaluates
    candles in array order, and downstream code addresses rows by position.

    Args:
        input_path: Path to the CSV file.

    Returns:
        DataFrame with an int64 ``time`` column in Unix seconds.

    Raises:
        ValueError: If any required column is missing or the time column
            cannot be parsed/converted.
    """
    logger.info(f"Loading OHLCV data from {input_path}")
    df = pd.read_csv(input_path)
    required_cols = ['time', 'open', 'high', 'low', 'close']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    # Any non-numeric dtype (object, or pandas' dedicated string dtype)
    # means the column holds date strings that must be parsed.
    if not pd.api.types.is_numeric_dtype(df['time']):
        logger.info("Converting date strings to Unix timestamps...")
        parsed = pd.to_datetime(df['time'], utc=True)
        # Divide by a Timedelta rather than assuming nanosecond storage:
        # pandas may infer a coarser datetime resolution for parsed strings.
        epoch = pd.Timestamp('1970-01-01', tz='UTC')
        df['time'] = (parsed - epoch) // pd.Timedelta(seconds=1)
    # Ensure time is a 64-bit integer (plain `int` is 32-bit on some platforms).
    df['time'] = df['time'].astype('int64')
    # Enforce chronological order with a positional index for later iloc use.
    df = df.sort_values('time', kind='stable').reset_index(drop=True)
    logger.info(f"Loaded {len(df)} candles")
    return df
def detect_talib_patterns(df: pd.DataFrame, patterns: List[str] = None) -> pd.DataFrame:
    """
    Run TA-Lib CDL* functions on OHLCV data.

    Args:
        df: DataFrame with time, open, high, low, close columns.
        patterns: TA-Lib function names to run (default: every key of
            TALIB_PATTERNS). Names not found in ``talib`` are logged as
            warnings and skipped.

    Returns:
        A copy of the time/open/high/low/close columns plus one column per
        pattern function holding its raw TA-Lib output. A function that
        raises is logged and gets a column of zeros, so one failure does not
        abort the whole run.
    """
    if patterns is None:
        patterns = list(TALIB_PATTERNS.keys())
    results = df[['time', 'open', 'high', 'low', 'close']].copy()
    # TA-Lib only accepts float64 ("double") input arrays. Integer-priced
    # CSVs would otherwise make every call raise and, through the zero
    # fallback below, silently yield all-zero pattern columns.
    open_prices = df['open'].to_numpy(dtype=np.float64)
    high_prices = df['high'].to_numpy(dtype=np.float64)
    low_prices = df['low'].to_numpy(dtype=np.float64)
    close_prices = df['close'].to_numpy(dtype=np.float64)
    logger.info(f"Running {len(patterns)} TA-Lib pattern detection functions...")
    for pattern_func in patterns:
        if not hasattr(talib, pattern_func):
            logger.warning(f"Unknown TA-Lib function: {pattern_func}")
            continue
        func = getattr(talib, pattern_func)
        try:
            pattern_values = func(open_prices, high_prices, low_prices, close_prices)
        except Exception as e:
            # Deliberate best effort: record "no pattern" for this function.
            logger.error(f"Error running {pattern_func}: {e}")
            results[pattern_func] = 0
        else:
            results[pattern_func] = pattern_values
    return results
def create_span_annotations(
    df: pd.DataFrame,
    min_confidence: int = 100,
    pattern_names: Dict[str, str] = None,
) -> List[Dict[str, Any]]:
    """
    Convert TA-Lib pattern detection results to span annotations.

    TA-Lib returns, per candle:
    - 0: No pattern
    - positive (typically 100): Bullish pattern
    - negative (typically -100): Bearish pattern
    Some functions (e.g. CDLHIKKAKE) can also emit +/-200 for a confirmed
    signal, so the normalized confidence is clamped to stay within 0-1.

    Each detection becomes a 3-candle span centred on the detection candle,
    clipped at the first/last row. Rows are addressed by position, so frames
    with any index (RangeIndex, DatetimeIndex, filtered) are handled.
    Missing (NaN) values are treated as "no pattern".

    Args:
        df: DataFrame with a ``time`` column plus one column per CDL*
            function (columns whose name starts with "CDL").
        min_confidence: Minimum absolute TA-Lib value to keep
            (100 = only full-strength matches).
        pattern_names: Optional mapping from CDL* column name to display
            name; defaults to TALIB_PATTERNS. Unmapped columns are labelled
            with their column name.

    Returns:
        List of annotation dicts, ordered by row and then by column order.
    """
    names = TALIB_PATTERNS if pattern_names is None else pattern_names
    pattern_cols = [col for col in df.columns if str(col).startswith('CDL')]
    if not pattern_cols or len(df) == 0:
        return []

    signals = df[pattern_cols].to_numpy(dtype=np.float64)
    times = df['time'].to_numpy()
    last_pos = len(df) - 1
    # NaN fails the >= comparison, so missing values never produce a hit.
    keep = (signals != 0) & (np.abs(signals) >= min_confidence)
    # np.nonzero on a 2-D mask yields hits in row-major order (by row, then
    # by column), matching a row-by-row scan of the frame.
    hit_rows, hit_cols = np.nonzero(keep)

    annotations = []
    for pos, col_pos in zip(hit_rows.tolist(), hit_cols.tolist()):
        pattern_col = pattern_cols[col_pos]
        pattern_value = float(signals[pos, col_pos])
        friendly_name = names.get(pattern_col, pattern_col)
        direction = 'Bullish' if pattern_value > 0 else 'Bearish'
        # 3-candle span centred on the detection point.
        # NOTE(review): TA-Lib generally flags multi-candle patterns on their
        # final candle; a trailing span may fit better. Confirm intent.
        start_time = int(times[max(0, pos - 1)])
        end_time = int(times[min(last_pos, pos + 1)])
        annotations.append({
            'start_time': start_time,
            'end_time': end_time,
            'label': f"{direction} {friendly_name}",
            'confidence': min(abs(pattern_value) / 100.0, 1.0),  # Normalize to 0-1
            'source': 'programmatic',
            'notes': f'TA-Lib {pattern_col} detection',
        })
    return annotations
def save_annotations(annotations: List[Dict[str, Any]], output_path: str) -> None:
    """
    Save annotations to a JSON file in a format compatible with the UI.

    Written shape:
        {"annotations": [...], "metadata": {"source": "talib", "count": N}}
    Parent directories of ``output_path`` are created as needed.

    Args:
        annotations: Annotation dicts; every value must be JSON-serializable.
        output_path: Destination file path.

    Raises:
        TypeError: If an annotation holds a value ``json`` cannot serialize.
            In that case no output file is written.
    """
    output_data = {
        'annotations': annotations,
        'metadata': {
            'source': 'talib',
            'count': len(annotations),
        },
    }
    # Serialize before touching the file so a serialization error cannot
    # leave a truncated, half-written JSON file behind.
    payload = json.dumps(output_data, indent=2)
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(payload, encoding='utf-8')
    logger.info(f"Saved {len(annotations)} annotations to {output_path}")
def main():
    """
    Command-line entry point.

    Loads OHLCV data, runs TA-Lib pattern detection, converts detections to
    span annotations, logs the label distribution and writes the JSON output.
    Exits with status 1 and a one-line error if the input cannot be loaded.
    """
    from collections import Counter

    parser = argparse.ArgumentParser(description='Generate span annotations from TA-Lib patterns')
    parser.add_argument('--input', '-i', required=True, help='Input OHLCV CSV file')
    parser.add_argument('--output', '-o', required=True, help='Output JSON file')
    parser.add_argument('--min-confidence', type=int, default=100,
                        help='Minimum confidence (0-100, default: 100 = perfect matches only)')
    parser.add_argument('--patterns', nargs='+',
                        help='Specific patterns to detect (default: all)')
    args = parser.parse_args()

    # Load data. Missing files (OSError) and malformed CSVs, missing columns
    # or unparseable times (ValueError subclasses) are reported as one line
    # instead of a traceback.
    try:
        df = load_ohlcv(args.input)
    except (OSError, ValueError) as e:
        logger.error(f"Could not load {args.input}: {e}")
        sys.exit(1)

    # Detect patterns
    patterns = args.patterns if args.patterns else list(TALIB_PATTERNS.keys())
    results_df = detect_talib_patterns(df, patterns)

    # Create annotations
    annotations = create_span_annotations(results_df, min_confidence=args.min_confidence)
    logger.info(f"Found {len(annotations)} pattern annotations")

    # Show label distribution
    label_counts = Counter(ann['label'] for ann in annotations)
    logger.info("Pattern distribution:")
    for label, count in label_counts.most_common():
        logger.info(f" {label}: {count}")

    # Save
    save_annotations(annotations, args.output)
    # No leading "\n": it would leave a dangling "[INFO] " line in the log.
    logger.info("Next steps:")
    logger.info("1. Review/edit annotations: Use the import script (coming soon)")
    logger.info("2. Or use directly for training: Copy to services/ml/data/annotations/export.json")
if __name__ == '__main__':
    # Run the CLI only when executed directly, never on import.
    main()