222 lines
7.1 KiB
Python
222 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Generate span annotations from TA-Lib candlestick pattern functions.
|
|
|
|
This script runs TA-Lib CDL* functions on OHLCV data and outputs
|
|
annotations in a format that can be imported into the Candle Annotator UI.
|
|
|
|
Usage:
|
|
python generate_talib_annotations.py --input data/raw/OHLCV.csv --output talib_annotations.json
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
try:
    import talib
except ImportError:
    # TA-Lib is a hard requirement: report the problem on stderr and stop.
    # sys.exit is used instead of the bare exit() builtin because exit() is
    # injected by the `site` module and is not guaranteed to exist
    # (e.g. when running with `python -S` or from a frozen executable).
    print("ERROR: TA-Lib not installed. Install with: pip install TA-Lib", file=sys.stderr)
    print("Note: You may need to install the C library first. See DEPLOYMENT.md", file=sys.stderr)
    sys.exit(1)
|
|
|
|
# Add app directory to path to import from app.patterns
|
|
sys.path.insert(0, str(Path(__file__).parent / 'app'))
|
|
|
|
from patterns import TALIB_PATTERNS
|
|
|
|
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_ohlcv(input_path: str) -> pd.DataFrame:
    """
    Load OHLCV data from CSV file.

    Expected columns: time, open, high, low, close[, volume]
    Time can be Unix timestamp or date string.

    Args:
        input_path: Path of the CSV file to read

    Returns:
        DataFrame with all CSV columns; 'time' is converted to int64 Unix seconds

    Raises:
        ValueError: If a required column is missing or a time value cannot be parsed
    """
    logger.info(f"Loading OHLCV data from {input_path}")
    df = pd.read_csv(input_path)

    required_cols = ['time', 'open', 'high', 'low', 'close']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Anything that is not already numeric is treated as a date string.
    # Checking for "not numeric" (rather than dtype == 'object') also covers
    # pandas' dedicated string dtype.
    if not pd.api.types.is_numeric_dtype(df['time']):
        logger.info("Converting date strings to Unix timestamps...")
        parsed = pd.to_datetime(df['time'], utc=True)
        if parsed.isna().any():
            raise ValueError("Column 'time' contains values that could not be parsed as dates")
        # Subtracting the epoch and floor-dividing by one second gives Unix
        # seconds regardless of the datetime resolution pandas picked
        # (a raw integer cast would depend on ns/us/ms/s resolution).
        epoch = pd.Timestamp('1970-01-01', tz='UTC')
        df['time'] = (parsed - epoch) // pd.Timedelta(seconds=1)

    # Ensure time is a 64-bit integer on every platform
    df['time'] = df['time'].astype('int64')

    logger.info(f"Loaded {len(df)} candles")
    return df
|
|
|
|
|
|
def detect_talib_patterns(df: pd.DataFrame, patterns: List[str] = None) -> pd.DataFrame:
    """
    Run TA-Lib CDL* functions on OHLCV data.

    Args:
        df: DataFrame with time, open, high, low, close columns
        patterns: List of pattern names to detect (default: all keys of TALIB_PATTERNS)

    Returns:
        Copy of the time/open/high/low/close columns plus one column per
        requested function that exists in talib. A function that raises is
        logged and its column is filled with 0.

    Raises:
        ValueError: If a price column cannot be converted to float
    """
    if patterns is None:
        patterns = list(TALIB_PATTERNS.keys())

    results = df[['time', 'open', 'high', 'low', 'close']].copy()

    # TA-Lib only accepts float64 ("double") input arrays. Casting once up
    # front keeps integer-priced data from making every pattern call raise
    # (which the handler below would turn into all-zero columns).
    open_prices = df['open'].to_numpy(dtype=np.float64)
    high_prices = df['high'].to_numpy(dtype=np.float64)
    low_prices = df['low'].to_numpy(dtype=np.float64)
    close_prices = df['close'].to_numpy(dtype=np.float64)

    logger.info(f"Running {len(patterns)} TA-Lib pattern detection functions...")

    for pattern_func in patterns:
        func = getattr(talib, pattern_func, None)
        if func is None:
            logger.warning(f"Unknown TA-Lib function: {pattern_func}")
            continue

        try:
            results[pattern_func] = func(open_prices, high_prices, low_prices, close_prices)
        except Exception as e:
            # Best effort: one failing pattern must not abort the whole run.
            logger.error(f"Error running {pattern_func}: {e}")
            results[pattern_func] = 0

    return results
|
|
|
|
|
|
def create_span_annotations(df: pd.DataFrame, min_confidence: int = 100) -> List[Dict[str, Any]]:
    """
    Convert TA-Lib pattern detection results to span annotations.

    Every column whose name starts with 'CDL' is read as a pattern signal:
    - 0: No pattern
    - positive (typically 100): Bullish pattern
    - negative (typically -100): Bearish pattern

    Args:
        df: DataFrame with a 'time' column and pattern detection columns
        min_confidence: Minimum absolute signal value (100 = only full-strength matches)

    Returns:
        List of annotation dicts, ordered by candle and then by pattern column.
        Each span covers the detection candle plus its neighbours on either
        side, clipped to the ends of the data.
    """
    pattern_cols = [col for col in df.columns if col.startswith('CDL')]
    if df.empty or not pattern_cols:
        return []

    # Work on plain arrays and *positions*: the DataFrame index may not be a
    # 0..n-1 range (e.g. after filtering), so index labels must not be used
    # to look up neighbouring candles.
    times = df['time'].to_numpy()
    signals = df[pattern_cols].to_numpy()
    last_pos = len(df) - 1

    # NaN fails the >= comparison, so missing signals are never reported.
    detected = (signals != 0) & (np.abs(signals) >= min_confidence)

    annotations = []
    # argwhere yields (row, column) pairs in row-major order, i.e. candle by
    # candle and, within a candle, in pattern-column order.
    for row_pos, col_pos in np.argwhere(detected).tolist():
        pattern_col = pattern_cols[col_pos]
        pattern_value = float(signals[row_pos, col_pos])

        friendly_name = TALIB_PATTERNS.get(pattern_col, pattern_col)
        direction = "Bullish" if pattern_value > 0 else "Bearish"

        annotations.append({
            # 3-candle span centered on the detection point
            'start_time': int(times[max(0, row_pos - 1)]),
            'end_time': int(times[min(last_pos, row_pos + 1)]),
            'label': f"{direction} {friendly_name}",
            # Normalize to 0-1. Signals stronger than 100 are clamped so the
            # value stays in range (NOTE(review): some TA-Lib functions appear
            # to report magnitudes up to 200 - confirm).
            'confidence': min(abs(pattern_value) / 100.0, 1.0),
            'source': 'programmatic',
            'notes': f'TA-Lib {pattern_col} detection',
        })

    return annotations
|
|
|
|
|
|
def save_annotations(annotations: List[Dict[str, Any]], output_path: str):
    """
    Write the annotations to a JSON file in the format the UI imports.

    The file holds the annotation list together with a small metadata
    section (source tag and annotation count). Missing parent directories
    of the output path are created.

    Args:
        annotations: Annotation dicts to serialize
        output_path: Destination path of the JSON file
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    metadata = {'source': 'talib', 'count': len(annotations)}
    payload = {'annotations': annotations, 'metadata': metadata}

    with destination.open('w') as handle:
        json.dump(payload, handle, indent=2)

    logger.info(f"Saved {len(annotations)} annotations to {output_path}")
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Parses the arguments, loads the OHLCV CSV, runs the requested TA-Lib
    pattern functions, converts the detections to span annotations, logs the
    per-label counts and writes the result to the output JSON file.
    """
    from collections import Counter

    parser = argparse.ArgumentParser(description='Generate span annotations from TA-Lib patterns')
    parser.add_argument('--input', '-i', required=True, help='Input OHLCV CSV file')
    parser.add_argument('--output', '-o', required=True, help='Output JSON file')
    parser.add_argument(
        '--min-confidence',
        type=int,
        default=100,
        help='Minimum confidence (0-100, default: 100 = perfect matches only)',
    )
    parser.add_argument(
        '--patterns',
        nargs='+',
        help='Specific patterns to detect (default: all)',
    )
    args = parser.parse_args()

    # Load -> detect -> convert
    candles = load_ohlcv(args.input)
    selected = args.patterns or list(TALIB_PATTERNS.keys())
    detections = detect_talib_patterns(candles, selected)
    annotations = create_span_annotations(detections, min_confidence=args.min_confidence)

    logger.info(f"Found {len(annotations)} pattern annotations")

    # Report how often each label occurs, most frequent first
    distribution = Counter(entry['label'] for entry in annotations)
    logger.info("Pattern distribution:")
    for label, count in distribution.most_common():
        logger.info(f" {label}: {count}")

    save_annotations(annotations, args.output)

    logger.info("\nNext steps:")
    logger.info("1. Review/edit annotations: Use the import script (coming soon)")
    logger.info("2. Or use directly for training: Copy to services/ml/data/annotations/export.json")
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|