feat(ml): implement feature engineering pipeline
- Create pipeline.py with CLI argument parsing for running stages - Implement TA-Lib indicator computation with multi-output support - Add candle feature extraction (body_size, wicks, ratios, etc.) - Create custom feature loader with dynamic module import - Wire all feature engineering stages with NaN handling - Tasks completed: 2.2, 2.3, 3.1, 3.2, 3.3, 3.4, 3.5
This commit is contained in:
parent
ea339a54a7
commit
fd29ab91e0
6 changed files with 889 additions and 7 deletions
134
services/ml/features/candle_features.py
Normal file
134
services/ml/features/candle_features.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Candle-derived feature extraction.
|
||||
|
||||
Computes geometric and structural features from OHLCV candlestick data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_candle_features(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Compute derived candle features for each row.
|
||||
|
||||
Features computed:
|
||||
- body_size: abs(close - open) — size of the candle body
|
||||
- body_direction: 1 if close >= open (bullish), -1 otherwise (bearish)
|
||||
- upper_wick: high - max(open, close) — upper shadow length
|
||||
- lower_wick: min(open, close) - low — lower shadow length
|
||||
- wick_ratio: upper_wick / lower_wick (0 if lower_wick is 0)
|
||||
- body_to_range: body_size / (high - low) — body as fraction of total range (0 if range is 0)
|
||||
- gap: open - previous close (0 for first candle)
|
||||
- range: high - low — total candle range
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV columns (open, high, low, close)
|
||||
|
||||
Returns:
|
||||
DataFrame with original columns + candle feature columns
|
||||
|
||||
Raises:
|
||||
ValueError: If required OHLCV columns are missing
|
||||
"""
|
||||
# Validate required columns
|
||||
required_cols = ['open', 'high', 'low', 'close']
|
||||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"Missing required OHLC columns: {missing_cols}")
|
||||
|
||||
logger.info("Computing candle features")
|
||||
|
||||
# Make a copy to avoid modifying the original
|
||||
result_df = df.copy()
|
||||
|
||||
# Body size
|
||||
result_df['body_size'] = np.abs(result_df['close'] - result_df['open'])
|
||||
|
||||
# Body direction
|
||||
result_df['body_direction'] = np.where(
|
||||
result_df['close'] >= result_df['open'],
|
||||
1, # Bullish
|
||||
-1 # Bearish
|
||||
)
|
||||
|
||||
# Upper wick
|
||||
result_df['upper_wick'] = result_df['high'] - np.maximum(
|
||||
result_df['open'],
|
||||
result_df['close']
|
||||
)
|
||||
|
||||
# Lower wick
|
||||
result_df['lower_wick'] = np.minimum(
|
||||
result_df['open'],
|
||||
result_df['close']
|
||||
) - result_df['low']
|
||||
|
||||
# Wick ratio (handle division by zero)
|
||||
result_df['wick_ratio'] = np.where(
|
||||
result_df['lower_wick'] != 0,
|
||||
result_df['upper_wick'] / result_df['lower_wick'],
|
||||
0.0
|
||||
)
|
||||
|
||||
# Range
|
||||
result_df['range'] = result_df['high'] - result_df['low']
|
||||
|
||||
# Body to range ratio (handle division by zero)
|
||||
result_df['body_to_range'] = np.where(
|
||||
result_df['range'] != 0,
|
||||
result_df['body_size'] / result_df['range'],
|
||||
0.0
|
||||
)
|
||||
|
||||
# Gap (open - previous close)
|
||||
# For the first candle, gap is 0
|
||||
result_df['gap'] = result_df['open'] - result_df['close'].shift(1)
|
||||
result_df['gap'].fillna(0.0, inplace=True)
|
||||
|
||||
logger.info("Computed 8 candle features: body_size, body_direction, upper_wick, "
|
||||
"lower_wick, wick_ratio, body_to_range, gap, range")
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
def validate_candle_data(df: pd.DataFrame) -> None:
|
||||
"""
|
||||
Validate OHLC data consistency.
|
||||
|
||||
Checks:
|
||||
- high >= low
|
||||
- high >= open
|
||||
- high >= close
|
||||
- low <= open
|
||||
- low <= close
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLC columns
|
||||
|
||||
Raises:
|
||||
ValueError: If data validation fails
|
||||
"""
|
||||
# Check high >= low
|
||||
invalid_hl = df[df['high'] < df['low']]
|
||||
if not invalid_hl.empty:
|
||||
logger.warning(f"Found {len(invalid_hl)} rows where high < low")
|
||||
|
||||
# Check high >= open and high >= close
|
||||
invalid_h = df[(df['high'] < df['open']) | (df['high'] < df['close'])]
|
||||
if not invalid_h.empty:
|
||||
logger.warning(f"Found {len(invalid_h)} rows where high < open or high < close")
|
||||
|
||||
# Check low <= open and low <= close
|
||||
invalid_l = df[(df['low'] > df['open']) | (df['low'] > df['close'])]
|
||||
if not invalid_l.empty:
|
||||
logger.warning(f"Found {len(invalid_l)} rows where low > open or low > close")
|
||||
|
||||
# If there are many invalid rows, this could indicate a data quality issue
|
||||
total_invalid = len(invalid_hl) + len(invalid_h) + len(invalid_l)
|
||||
if total_invalid > 0:
|
||||
logger.warning(f"Total invalid candles: {total_invalid} out of {len(df)}")
|
||||
136
services/ml/features/custom_loader.py
Normal file
136
services/ml/features/custom_loader.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
Custom feature function loader.
|
||||
|
||||
Dynamically imports and executes custom feature functions from configured module paths.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import importlib
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_custom_features(
|
||||
df: pd.DataFrame,
|
||||
custom_feature_paths: List[str]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Load and apply custom feature functions.
|
||||
|
||||
Each custom feature path should be a Python module path (e.g., "features.custom.trend_slope").
|
||||
The module should define a function with the same name as the module's last component.
|
||||
The function should accept a pandas DataFrame and return a pandas Series.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV + computed features
|
||||
custom_feature_paths: List of module paths to import
|
||||
|
||||
Returns:
|
||||
DataFrame with original columns + custom feature columns
|
||||
|
||||
Raises:
|
||||
ImportError: If a custom feature module cannot be imported
|
||||
AttributeError: If the expected function is not found in the module
|
||||
ValueError: If the custom function doesn't return a Series
|
||||
"""
|
||||
if not custom_feature_paths:
|
||||
logger.debug("No custom features configured")
|
||||
return df
|
||||
|
||||
logger.info(f"Loading {len(custom_feature_paths)} custom feature(s)")
|
||||
|
||||
# Make a copy to avoid modifying the original
|
||||
result_df = df.copy()
|
||||
|
||||
for feature_path in custom_feature_paths:
|
||||
logger.debug(f"Loading custom feature: {feature_path}")
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(feature_path)
|
||||
|
||||
# Get the function name (last component of the path)
|
||||
function_name = feature_path.split('.')[-1]
|
||||
|
||||
# Get the function from the module
|
||||
if not hasattr(module, function_name):
|
||||
raise AttributeError(
|
||||
f"Module '{feature_path}' does not have a function named '{function_name}'. "
|
||||
f"Custom feature modules must define a function with the same name as the module."
|
||||
)
|
||||
|
||||
feature_func = getattr(module, function_name)
|
||||
|
||||
# Call the function with the current DataFrame
|
||||
logger.debug(f"Calling custom feature function: {function_name}")
|
||||
feature_result = feature_func(result_df)
|
||||
|
||||
# Validate the result is a Series
|
||||
if not isinstance(feature_result, pd.Series):
|
||||
raise ValueError(
|
||||
f"Custom feature function '{function_name}' must return a pandas Series, "
|
||||
f"but returned {type(feature_result).__name__}"
|
||||
)
|
||||
|
||||
# Check the Series has the right length
|
||||
if len(feature_result) != len(result_df):
|
||||
raise ValueError(
|
||||
f"Custom feature function '{function_name}' returned Series with "
|
||||
f"{len(feature_result)} rows, but DataFrame has {len(result_df)} rows"
|
||||
)
|
||||
|
||||
# Add the feature as a new column
|
||||
result_df[function_name] = feature_result.values
|
||||
|
||||
logger.info(f"Added custom feature: {function_name}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import custom feature module '{feature_path}': {e}")
|
||||
raise ImportError(
|
||||
f"Cannot import custom feature module '{feature_path}'. "
|
||||
f"Ensure the module exists and is in the Python path. Error: {e}"
|
||||
)
|
||||
except AttributeError as e:
|
||||
logger.error(f"Custom feature function not found: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying custom feature '{feature_path}': {e}")
|
||||
raise
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
def validate_custom_feature_function(func, function_name: str) -> None:
|
||||
"""
|
||||
Validate that a custom feature function has the correct signature.
|
||||
|
||||
Args:
|
||||
func: The function to validate
|
||||
function_name: Name of the function for error messages
|
||||
|
||||
Raises:
|
||||
ValueError: If the function signature is invalid
|
||||
"""
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
if len(params) != 1:
|
||||
raise ValueError(
|
||||
f"Custom feature function '{function_name}' must accept exactly 1 parameter "
|
||||
f"(a pandas DataFrame), but has {len(params)} parameters"
|
||||
)
|
||||
|
||||
# Check if the parameter is annotated as DataFrame (optional check)
|
||||
param = params[0]
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
if param.annotation not in [pd.DataFrame, 'pd.DataFrame', 'DataFrame']:
|
||||
logger.warning(
|
||||
f"Custom feature function '{function_name}' parameter is annotated as "
|
||||
f"{param.annotation}, but should be pd.DataFrame"
|
||||
)
|
||||
143
services/ml/features/engineer.py
Normal file
143
services/ml/features/engineer.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""
|
||||
Feature engineering stage orchestrator.
|
||||
|
||||
Coordinates TA-Lib indicators, candle features, and custom features.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from app.config import PipelineConfig
|
||||
from features.talib_features import compute_talib_indicators
|
||||
from features.candle_features import compute_candle_features, validate_candle_data
|
||||
from features.custom_loader import load_custom_features
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_feature_engineering_stage(config: PipelineConfig) -> None:
|
||||
"""
|
||||
Run the complete feature engineering stage.
|
||||
|
||||
Steps:
|
||||
1. Load raw OHLCV data
|
||||
2. Validate OHLC data consistency
|
||||
3. Compute TA-Lib indicators (if enabled)
|
||||
4. Compute candle features (if enabled)
|
||||
5. Load custom features (if configured)
|
||||
6. Handle NaN values from indicator warmup periods
|
||||
7. Write enriched CSV
|
||||
|
||||
Args:
|
||||
config: Pipeline configuration
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If raw data file doesn't exist
|
||||
ValueError: If data validation fails
|
||||
"""
|
||||
fe_config = config.stages.feature_engineering
|
||||
data_config = config.data
|
||||
|
||||
# Load raw data
|
||||
raw_path = Path(data_config.raw_path)
|
||||
if not raw_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Raw data file not found: {raw_path}. "
|
||||
f"Please ensure OHLCV data is available at this path."
|
||||
)
|
||||
|
||||
logger.info(f"Loading raw OHLCV data from: {raw_path}")
|
||||
df = pd.read_csv(raw_path)
|
||||
|
||||
logger.info(f"Loaded {len(df)} rows with columns: {list(df.columns)}")
|
||||
|
||||
# Validate OHLC data
|
||||
validate_candle_data(df)
|
||||
|
||||
original_rows = len(df)
|
||||
|
||||
# Compute TA-Lib indicators
|
||||
if fe_config.talib_indicators:
|
||||
logger.info(f"Computing {len(fe_config.talib_indicators)} TA-Lib indicators")
|
||||
df = compute_talib_indicators(df, fe_config.talib_indicators)
|
||||
else:
|
||||
logger.info("No TA-Lib indicators configured, skipping")
|
||||
|
||||
# Compute candle features
|
||||
if fe_config.candle_features:
|
||||
logger.info("Computing candle features")
|
||||
df = compute_candle_features(df)
|
||||
else:
|
||||
logger.info("Candle features disabled, skipping")
|
||||
|
||||
# Load custom features
|
||||
if fe_config.custom_features:
|
||||
logger.info(f"Loading {len(fe_config.custom_features)} custom feature(s)")
|
||||
df = load_custom_features(df, fe_config.custom_features)
|
||||
else:
|
||||
logger.info("No custom features configured, skipping")
|
||||
|
||||
# Handle NaN values from indicator warmup periods
|
||||
df = handle_indicator_warmup(df, original_rows)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
enriched_path = Path(data_config.enriched_path)
|
||||
enriched_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write enriched data
|
||||
logger.info(f"Writing enriched data to: {enriched_path}")
|
||||
df.to_csv(enriched_path, index=False)
|
||||
|
||||
logger.info(
|
||||
f"Feature engineering complete: {original_rows} rows -> {len(df)} rows "
|
||||
f"({original_rows - len(df)} dropped), {len(df.columns)} columns"
|
||||
)
|
||||
|
||||
|
||||
def handle_indicator_warmup(df: pd.DataFrame, original_rows: int) -> pd.DataFrame:
|
||||
"""
|
||||
Handle NaN values introduced by indicator warmup periods.
|
||||
|
||||
Rows with NaN values in any column are dropped. This is necessary because
|
||||
indicators like RSI, MACD, etc. need a warmup period before producing valid values.
|
||||
|
||||
Args:
|
||||
df: DataFrame with computed indicators
|
||||
original_rows: Number of rows before computing indicators
|
||||
|
||||
Returns:
|
||||
DataFrame with NaN rows dropped
|
||||
"""
|
||||
# Count NaN values before dropping
|
||||
nan_counts = df.isnull().sum()
|
||||
cols_with_nan = nan_counts[nan_counts > 0]
|
||||
|
||||
if not cols_with_nan.empty:
|
||||
logger.info("Columns with NaN values (indicator warmup):")
|
||||
for col, count in cols_with_nan.items():
|
||||
logger.info(f" {col}: {count} NaN values")
|
||||
|
||||
# Drop rows with any NaN values
|
||||
df_clean = df.dropna()
|
||||
|
||||
rows_dropped = original_rows - len(df_clean)
|
||||
|
||||
if rows_dropped > 0:
|
||||
logger.info(
|
||||
f"Dropped {rows_dropped} rows due to indicator warmup "
|
||||
f"({rows_dropped / original_rows * 100:.1f}% of original data)"
|
||||
)
|
||||
|
||||
# Warn if too much data was dropped
|
||||
if rows_dropped / original_rows > 0.1:
|
||||
logger.warning(
|
||||
f"More than 10% of data was dropped due to indicator warmup. "
|
||||
f"Consider reducing indicator periods or using more historical data."
|
||||
)
|
||||
else:
|
||||
logger.info("No rows dropped (no NaN values from indicators)")
|
||||
|
||||
return df_clean
|
||||
262
services/ml/features/talib_features.py
Normal file
262
services/ml/features/talib_features.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
"""
|
||||
TA-Lib technical indicator computation.
|
||||
|
||||
Computes technical indicators from raw OHLCV data using TA-Lib.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from app.config import TALibIndicator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_talib_indicators(
|
||||
df: pd.DataFrame,
|
||||
indicators: List[TALibIndicator]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Compute TA-Lib indicators and append as columns.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV columns (open, high, low, close, volume)
|
||||
indicators: List of indicator configurations from pipeline config
|
||||
|
||||
Returns:
|
||||
DataFrame with original columns + computed indicator columns
|
||||
|
||||
Raises:
|
||||
ImportError: If TA-Lib is not installed
|
||||
ValueError: If required OHLCV columns are missing
|
||||
AttributeError: If an indicator name is not valid
|
||||
"""
|
||||
# Check if TA-Lib is installed
|
||||
try:
|
||||
import talib
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"TA-Lib is not installed. Please install the TA-Lib C library first:\n\n"
|
||||
" Ubuntu/Debian: sudo apt-get install libta-lib-dev\n"
|
||||
" macOS: brew install ta-lib\n"
|
||||
" Windows: Download from https://www.ta-lib.org/\n\n"
|
||||
"Then install the Python wrapper: pip install TA-Lib\n"
|
||||
)
|
||||
|
||||
# Validate required columns
|
||||
required_cols = ['open', 'high', 'low', 'close', 'volume']
|
||||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"Missing required OHLCV columns: {missing_cols}")
|
||||
|
||||
# Make a copy to avoid modifying the original
|
||||
result_df = df.copy()
|
||||
|
||||
# Extract OHLCV arrays (TA-Lib expects numpy arrays)
|
||||
open_prices = df['open'].values
|
||||
high_prices = df['high'].values
|
||||
low_prices = df['low'].values
|
||||
close_prices = df['close'].values
|
||||
volume = df['volume'].values
|
||||
|
||||
logger.info(f"Computing {len(indicators)} TA-Lib indicators")
|
||||
|
||||
for indicator_config in indicators:
|
||||
indicator_name = indicator_config.name.upper()
|
||||
params = indicator_config.params
|
||||
|
||||
# Check if indicator function exists
|
||||
if not hasattr(talib, indicator_name):
|
||||
raise AttributeError(
|
||||
f"TA-Lib indicator '{indicator_name}' not found. "
|
||||
f"Check TA-Lib documentation for valid indicator names."
|
||||
)
|
||||
|
||||
indicator_func = getattr(talib, indicator_name)
|
||||
|
||||
try:
|
||||
# Call the TA-Lib function with OHLCV data and parameters
|
||||
result = _call_talib_function(
|
||||
indicator_func,
|
||||
indicator_name,
|
||||
open_prices,
|
||||
high_prices,
|
||||
low_prices,
|
||||
close_prices,
|
||||
volume,
|
||||
params
|
||||
)
|
||||
|
||||
# Add result columns to DataFrame
|
||||
result_df = _add_indicator_columns(
|
||||
result_df,
|
||||
indicator_name,
|
||||
result,
|
||||
params
|
||||
)
|
||||
|
||||
logger.debug(f"Computed indicator: {indicator_name} with params {params}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute {indicator_name}: {e}")
|
||||
raise
|
||||
|
||||
# Count new columns added
|
||||
new_cols = len(result_df.columns) - len(df.columns)
|
||||
logger.info(f"Added {new_cols} indicator columns")
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
def _call_talib_function(
|
||||
func,
|
||||
name: str,
|
||||
open_prices: np.ndarray,
|
||||
high_prices: np.ndarray,
|
||||
low_prices: np.ndarray,
|
||||
close_prices: np.ndarray,
|
||||
volume: np.ndarray,
|
||||
params: Dict
|
||||
):
|
||||
"""
|
||||
Call a TA-Lib function with appropriate inputs.
|
||||
|
||||
Different indicators require different inputs (price only, OHLC, OHLCV, etc.)
|
||||
This function handles the common cases.
|
||||
"""
|
||||
# Price indicators (close only)
|
||||
price_only = ['RSI', 'EMA', 'SMA', 'WMA', 'TEMA', 'DEMA', 'TRIMA', 'KAMA',
|
||||
'MAMA', 'T3', 'CCI', 'CMO', 'MOM', 'ROC', 'ROCP', 'ROCR',
|
||||
'TRIX', 'WILLR', 'DX', 'MINUS_DI', 'PLUS_DI', 'MINUS_DM',
|
||||
'PLUS_DM', 'TSF', 'HT_DCPERIOD', 'HT_DCPHASE', 'HT_PHASOR',
|
||||
'HT_SINE', 'HT_TRENDMODE']
|
||||
|
||||
# High-Low-Close indicators
|
||||
hlc_indicators = ['ULTOSC', 'NATR']
|
||||
|
||||
# OHLC indicators
|
||||
ohlc_indicators = ['CDL2CROWS', 'CDL3BLACKCROWS', 'CDL3INSIDE', 'CDL3LINESTRIKE',
|
||||
'CDL3OUTSIDE', 'CDL3STARSINSOUTH', 'CDL3WHITESOLDIERS',
|
||||
'CDLABANDONEDBABY', 'CDLADVANCEBLOCK', 'CDLBELTHOLD',
|
||||
'CDLBREAKAWAY', 'CDLCLOSINGMARUBOZU', 'CDLCONCEALBABYSWALL',
|
||||
'CDLCOUNTERATTACK', 'CDLDARKCLOUDCOVER', 'CDLDOJI',
|
||||
'CDLDOJISTAR', 'CDLDRAGONFLYDOJI', 'CDLENGULFING',
|
||||
'CDLEVENINGDOJISTAR', 'CDLEVENINGSTAR', 'CDLGAPSIDESIDEWHITE',
|
||||
'CDLGRAVESTONEDOJI', 'CDLHAMMER', 'CDLHANGINGMAN',
|
||||
'CDLHARAMI', 'CDLHARAMICROSS', 'CDLHIGHWAVE', 'CDLHIKKAKE',
|
||||
'CDLHIKKAKEMOD', 'CDLHOMINGPIGEON', 'CDLIDENTICAL3CROWS',
|
||||
'CDLINNECK', 'CDLINVERTEDHAMMER', 'CDLKICKING',
|
||||
'CDLKICKINGBYLENGTH', 'CDLLADDERBOTTOM', 'CDLLONGLEGGEDDOJI',
|
||||
'CDLLONGLINE', 'CDLMARUBOZU', 'CDLMATCHINGLOW',
|
||||
'CDLMATHOLD', 'CDLMORNINGDOJISTAR', 'CDLMORNINGSTAR',
|
||||
'CDLONNECK', 'CDLPIERCING', 'CDLRICKSHAWMAN',
|
||||
'CDLRISEFALL3METHODS', 'CDLSEPARATINGLINES', 'CDLSHOOTINGSTAR',
|
||||
'CDLSHORTLINE', 'CDLSPINNINGTOP', 'CDLSTALLEDPATTERN',
|
||||
'CDLSTICKSANDWICH', 'CDLTAKURI', 'CDLTASUKIGAP',
|
||||
'CDLTHRUSTING', 'CDLTRISTAR', 'CDLUNIQUE3RIVER',
|
||||
'CDLUPSIDEGAP2CROWS', 'CDLXSIDEGAP3METHODS']
|
||||
|
||||
# Volume indicators
|
||||
volume_indicators = ['OBV', 'AD', 'ADOSC', 'MFI']
|
||||
|
||||
# High-Low indicators
|
||||
hl_indicators = ['AROON', 'AROONOSC', 'MINUS_DM', 'PLUS_DM']
|
||||
|
||||
if name in price_only:
|
||||
return func(close_prices, **params)
|
||||
elif name in hlc_indicators:
|
||||
return func(high_prices, low_prices, close_prices, **params)
|
||||
elif name in ohlc_indicators:
|
||||
return func(open_prices, high_prices, low_prices, close_prices, **params)
|
||||
elif name in volume_indicators:
|
||||
if name == 'OBV':
|
||||
return func(close_prices, volume, **params)
|
||||
elif name in ['AD', 'ADOSC']:
|
||||
return func(high_prices, low_prices, close_prices, volume, **params)
|
||||
elif name == 'MFI':
|
||||
return func(high_prices, low_prices, close_prices, volume, **params)
|
||||
elif name in hl_indicators:
|
||||
return func(high_prices, low_prices, **params)
|
||||
else:
|
||||
# Default: try with high, low, close (most common)
|
||||
try:
|
||||
return func(high_prices, low_prices, close_prices, **params)
|
||||
except TypeError:
|
||||
# If that fails, try with just close
|
||||
return func(close_prices, **params)
|
||||
|
||||
|
||||
def _add_indicator_columns(
|
||||
df: pd.DataFrame,
|
||||
indicator_name: str,
|
||||
result,
|
||||
params: Dict
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add indicator result(s) as column(s) to DataFrame.
|
||||
|
||||
Handles single-output and multi-output indicators.
|
||||
Column names follow the pattern: {indicator_lower}_{param} or just {indicator_lower}
|
||||
"""
|
||||
indicator_lower = indicator_name.lower()
|
||||
|
||||
# Check if result is a tuple (multi-output indicator like MACD, BBANDS, STOCH)
|
||||
if isinstance(result, tuple):
|
||||
# Multi-output indicator
|
||||
output_names = _get_output_names(indicator_name, len(result))
|
||||
|
||||
for i, (output_name, values) in enumerate(zip(output_names, result)):
|
||||
col_name = f"{indicator_lower}_{output_name}"
|
||||
df[col_name] = values
|
||||
else:
|
||||
# Single-output indicator
|
||||
# Add parameter to column name if there's a significant param
|
||||
if params:
|
||||
# Use the first parameter value in the column name
|
||||
# Common params: timeperiod, fastperiod, etc.
|
||||
param_key = list(params.keys())[0]
|
||||
param_val = params[param_key]
|
||||
col_name = f"{indicator_lower}_{param_val}"
|
||||
else:
|
||||
col_name = indicator_lower
|
||||
|
||||
df[col_name] = result
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def _get_output_names(indicator_name: str, num_outputs: int) -> List[str]:
|
||||
"""
|
||||
Get output names for multi-output indicators.
|
||||
|
||||
Args:
|
||||
indicator_name: Name of the indicator (e.g., "MACD", "BBANDS")
|
||||
num_outputs: Number of outputs from the indicator
|
||||
|
||||
Returns:
|
||||
List of output names (e.g., ["macd", "signal", "hist"])
|
||||
"""
|
||||
# Known multi-output indicators
|
||||
output_mappings = {
|
||||
'MACD': ['macd', 'signal', 'hist'],
|
||||
'MACDEXT': ['macd', 'signal', 'hist'],
|
||||
'MACDFIX': ['macd', 'signal', 'hist'],
|
||||
'BBANDS': ['upper', 'middle', 'lower'],
|
||||
'STOCH': ['slowk', 'slowd'],
|
||||
'STOCHF': ['fastk', 'fastd'],
|
||||
'STOCHRSI': ['fastk', 'fastd'],
|
||||
'AROON': ['aroondown', 'aroonup'],
|
||||
'HT_PHASOR': ['inphase', 'quadrature'],
|
||||
'HT_SINE': ['sine', 'leadsine'],
|
||||
'MAMA': ['mama', 'fama'],
|
||||
}
|
||||
|
||||
if indicator_name in output_mappings:
|
||||
return output_mappings[indicator_name]
|
||||
|
||||
# Default: generic names
|
||||
return [f"output{i}" for i in range(num_outputs)]
|
||||
Loading…
Add table
Add a link
Reference in a new issue