270 lines
9.8 KiB
Python
270 lines
9.8 KiB
Python
"""
|
|
TA-Lib technical indicator computation.
|
|
|
|
Computes technical indicators from raw OHLCV data using TA-Lib.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, List
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from app.config import TALibIndicator
|
|
|
|
|
|
# Module-wide logger, keyed by this module's dotted path so its records can be
# filtered or re-levelled independently of the rest of the application.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def compute_talib_indicators(
    df: pd.DataFrame,
    indicators: List[TALibIndicator]
) -> pd.DataFrame:
    """
    Compute TA-Lib indicators and append them to a copy of ``df`` as columns.

    Args:
        df: DataFrame with ``open``, ``high``, ``low`` and ``close`` columns.
            A ``volume`` column is optional; indicators that need it
            (OBV, AD, ADOSC, MFI) are skipped with a warning when it is absent.
        indicators: Indicator configurations from the pipeline config. Each
            entry's ``name`` is upper-cased and resolved on the ``talib``
            module; its ``params`` mapping (``None`` is treated as empty) is
            forwarded to the TA-Lib function as keyword arguments.

    Returns:
        A copy of ``df`` with the computed indicator columns appended; the
        input DataFrame itself is not modified.

    Raises:
        ImportError: If TA-Lib is not installed.
        ValueError: If any of the required OHLC columns is missing.
        AttributeError: If an indicator name is not a callable TA-Lib function.
    """
    # Imported lazily so this module can be imported without TA-Lib present.
    try:
        import talib
    except ImportError as err:
        raise ImportError(
            "TA-Lib is not installed. Please install the TA-Lib C library first:\n\n"
            "  Ubuntu/Debian: sudo apt-get install libta-lib-dev\n"
            "  macOS: brew install ta-lib\n"
            "  Windows: Download from https://www.ta-lib.org/\n\n"
            "Then install the Python wrapper: pip install TA-Lib\n"
        ) from err

    # Validate required columns (volume is optional).
    required_cols = ['open', 'high', 'low', 'close']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required OHLC columns: {missing_cols}")

    has_volume = 'volume' in df.columns

    # Work on a copy so the caller's DataFrame is never mutated.
    result_df = df.copy()

    # TA-Lib expects float64 (C double) numpy arrays.
    open_prices = df['open'].values.astype(np.float64)
    high_prices = df['high'].values.astype(np.float64)
    low_prices = df['low'].values.astype(np.float64)
    close_prices = df['close'].values.astype(np.float64)
    volume = df['volume'].values.astype(np.float64) if has_volume else None

    # Indicators that cannot be computed without a volume series. This is
    # loop-invariant, so it is built once instead of on every iteration.
    volume_indicators = frozenset({'OBV', 'AD', 'ADOSC', 'MFI'})

    logger.info("Computing %d TA-Lib indicators", len(indicators))

    for indicator_config in indicators:
        indicator_name = indicator_config.name.upper()
        # A missing/None params entry would otherwise crash at ``**params``.
        params = indicator_config.params or {}

        # Skip volume-dependent indicators when volume data is absent.
        if indicator_name in volume_indicators and volume is None:
            logger.warning("Skipping %s: requires volume data", indicator_name)
            continue

        # ``callable`` (not just ``hasattr``) so that non-function module
        # attributes are rejected up front rather than failing when called.
        indicator_func = getattr(talib, indicator_name, None)
        if not callable(indicator_func):
            raise AttributeError(
                f"TA-Lib indicator '{indicator_name}' not found. "
                "Check TA-Lib documentation for valid indicator names."
            )

        try:
            result = _call_talib_function(
                indicator_func,
                indicator_name,
                open_prices,
                high_prices,
                low_prices,
                close_prices,
                volume,
                params,
            )
            result_df = _add_indicator_columns(
                result_df,
                indicator_name,
                result,
                params,
            )
            logger.debug(
                "Computed indicator: %s with params %s", indicator_name, params
            )
        except Exception as e:
            # Record which indicator failed, then propagate the original error.
            logger.error("Failed to compute %s: %s", indicator_name, e)
            raise

    # Overwritten columns do not change the count, so this is the number of
    # genuinely new columns.
    new_cols = len(result_df.columns) - len(df.columns)
    logger.info("Added %d indicator columns", new_cols)

    return result_df
|
|
|
|
|
|
def _call_talib_function(
|
|
func,
|
|
name: str,
|
|
open_prices: np.ndarray,
|
|
high_prices: np.ndarray,
|
|
low_prices: np.ndarray,
|
|
close_prices: np.ndarray,
|
|
volume: np.ndarray,
|
|
params: Dict
|
|
):
|
|
"""
|
|
Call a TA-Lib function with appropriate inputs.
|
|
|
|
Different indicators require different inputs (price only, OHLC, OHLCV, etc.)
|
|
This function handles the common cases.
|
|
"""
|
|
# Price indicators (close only)
|
|
price_only = ['RSI', 'EMA', 'SMA', 'WMA', 'TEMA', 'DEMA', 'TRIMA', 'KAMA',
|
|
'MAMA', 'T3', 'CMO', 'MOM', 'ROC', 'ROCP', 'ROCR',
|
|
'TRIX', 'WILLR', 'DX', 'MINUS_DI', 'PLUS_DI', 'MINUS_DM',
|
|
'PLUS_DM', 'TSF', 'HT_DCPERIOD', 'HT_DCPHASE', 'HT_PHASOR',
|
|
'HT_SINE', 'HT_TRENDMODE']
|
|
|
|
# High-Low-Close indicators
|
|
hlc_indicators = ['ULTOSC', 'NATR', 'CCI']
|
|
|
|
# OHLC indicators
|
|
ohlc_indicators = ['CDL2CROWS', 'CDL3BLACKCROWS', 'CDL3INSIDE', 'CDL3LINESTRIKE',
|
|
'CDL3OUTSIDE', 'CDL3STARSINSOUTH', 'CDL3WHITESOLDIERS',
|
|
'CDLABANDONEDBABY', 'CDLADVANCEBLOCK', 'CDLBELTHOLD',
|
|
'CDLBREAKAWAY', 'CDLCLOSINGMARUBOZU', 'CDLCONCEALBABYSWALL',
|
|
'CDLCOUNTERATTACK', 'CDLDARKCLOUDCOVER', 'CDLDOJI',
|
|
'CDLDOJISTAR', 'CDLDRAGONFLYDOJI', 'CDLENGULFING',
|
|
'CDLEVENINGDOJISTAR', 'CDLEVENINGSTAR', 'CDLGAPSIDESIDEWHITE',
|
|
'CDLGRAVESTONEDOJI', 'CDLHAMMER', 'CDLHANGINGMAN',
|
|
'CDLHARAMI', 'CDLHARAMICROSS', 'CDLHIGHWAVE', 'CDLHIKKAKE',
|
|
'CDLHIKKAKEMOD', 'CDLHOMINGPIGEON', 'CDLIDENTICAL3CROWS',
|
|
'CDLINNECK', 'CDLINVERTEDHAMMER', 'CDLKICKING',
|
|
'CDLKICKINGBYLENGTH', 'CDLLADDERBOTTOM', 'CDLLONGLEGGEDDOJI',
|
|
'CDLLONGLINE', 'CDLMARUBOZU', 'CDLMATCHINGLOW',
|
|
'CDLMATHOLD', 'CDLMORNINGDOJISTAR', 'CDLMORNINGSTAR',
|
|
'CDLONNECK', 'CDLPIERCING', 'CDLRICKSHAWMAN',
|
|
'CDLRISEFALL3METHODS', 'CDLSEPARATINGLINES', 'CDLSHOOTINGSTAR',
|
|
'CDLSHORTLINE', 'CDLSPINNINGTOP', 'CDLSTALLEDPATTERN',
|
|
'CDLSTICKSANDWICH', 'CDLTAKURI', 'CDLTASUKIGAP',
|
|
'CDLTHRUSTING', 'CDLTRISTAR', 'CDLUNIQUE3RIVER',
|
|
'CDLUPSIDEGAP2CROWS', 'CDLXSIDEGAP3METHODS']
|
|
|
|
# Volume indicators
|
|
volume_indicators = ['OBV', 'AD', 'ADOSC', 'MFI']
|
|
|
|
# High-Low indicators
|
|
hl_indicators = ['AROON', 'AROONOSC', 'MINUS_DM', 'PLUS_DM']
|
|
|
|
if name in price_only:
|
|
return func(close_prices, **params)
|
|
elif name in hlc_indicators:
|
|
return func(high_prices, low_prices, close_prices, **params)
|
|
elif name in ohlc_indicators:
|
|
return func(open_prices, high_prices, low_prices, close_prices, **params)
|
|
elif name in volume_indicators:
|
|
if name == 'OBV':
|
|
return func(close_prices, volume, **params)
|
|
elif name in ['AD', 'ADOSC']:
|
|
return func(high_prices, low_prices, close_prices, volume, **params)
|
|
elif name == 'MFI':
|
|
return func(high_prices, low_prices, close_prices, volume, **params)
|
|
elif name in hl_indicators:
|
|
return func(high_prices, low_prices, **params)
|
|
else:
|
|
# Default: try with high, low, close (most common)
|
|
try:
|
|
return func(high_prices, low_prices, close_prices, **params)
|
|
except TypeError:
|
|
# If that fails, try with just close
|
|
return func(close_prices, **params)
|
|
|
|
|
|
def _add_indicator_columns(
|
|
df: pd.DataFrame,
|
|
indicator_name: str,
|
|
result,
|
|
params: Dict
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Add indicator result(s) as column(s) to DataFrame.
|
|
|
|
Handles single-output and multi-output indicators.
|
|
Column names follow the pattern: {indicator_lower}_{param} or just {indicator_lower}
|
|
"""
|
|
indicator_lower = indicator_name.lower()
|
|
|
|
# Check if result is a tuple (multi-output indicator like MACD, BBANDS, STOCH)
|
|
if isinstance(result, tuple):
|
|
# Multi-output indicator
|
|
output_names = _get_output_names(indicator_name, len(result))
|
|
|
|
for i, (output_name, values) in enumerate(zip(output_names, result)):
|
|
col_name = f"{indicator_lower}_{output_name}"
|
|
df[col_name] = values
|
|
else:
|
|
# Single-output indicator
|
|
# Add parameter to column name if there's a significant param
|
|
if params:
|
|
# Use the first parameter value in the column name
|
|
# Common params: timeperiod, fastperiod, etc.
|
|
param_key = list(params.keys())[0]
|
|
param_val = params[param_key]
|
|
col_name = f"{indicator_lower}_{param_val}"
|
|
else:
|
|
col_name = indicator_lower
|
|
|
|
df[col_name] = result
|
|
|
|
return df
|
|
|
|
|
|
def _get_output_names(indicator_name: str, num_outputs: int) -> List[str]:
|
|
"""
|
|
Get output names for multi-output indicators.
|
|
|
|
Args:
|
|
indicator_name: Name of the indicator (e.g., "MACD", "BBANDS")
|
|
num_outputs: Number of outputs from the indicator
|
|
|
|
Returns:
|
|
List of output names (e.g., ["macd", "signal", "hist"])
|
|
"""
|
|
# Known multi-output indicators
|
|
output_mappings = {
|
|
'MACD': ['macd', 'signal', 'hist'],
|
|
'MACDEXT': ['macd', 'signal', 'hist'],
|
|
'MACDFIX': ['macd', 'signal', 'hist'],
|
|
'BBANDS': ['upper', 'middle', 'lower'],
|
|
'STOCH': ['slowk', 'slowd'],
|
|
'STOCHF': ['fastk', 'fastd'],
|
|
'STOCHRSI': ['fastk', 'fastd'],
|
|
'AROON': ['aroondown', 'aroonup'],
|
|
'HT_PHASOR': ['inphase', 'quadrature'],
|
|
'HT_SINE': ['sine', 'leadsine'],
|
|
'MAMA': ['mama', 'fama'],
|
|
}
|
|
|
|
if indicator_name in output_mappings:
|
|
return output_mappings[indicator_name]
|
|
|
|
# Default: generic names
|
|
return [f"output{i}" for i in range(num_outputs)]
|