feat: add FastAPI pattern detection endpoints (Section 1)
- Extract CDL pattern detection logic into services/ml/app/patterns.py
with TALIB_PATTERNS dict, get_available_patterns(), validate_pattern_names(),
and detect_patterns(candles, pattern_names) functions
- Add GET /patterns/available endpoint returning all 54 supported CDL pattern
names with display names
- Add POST /patterns/detect endpoint accepting {candles, patterns}, running
selected CDL functions, returning span annotations with source "talib"
- Add input validation: reject invalid pattern names with HTTP 400,
treat empty patterns list as "run all"
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
38df874255
commit
b8e649e333
3 changed files with 320 additions and 0 deletions
62
openspec/changes/ml-ui-connection/tasks.md
Normal file
62
openspec/changes/ml-ui-connection/tasks.md
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
## 1. FastAPI Pattern Detection Endpoints
|
||||
|
||||
- [x] 1.1 Extract pattern detection logic from `generate_talib_annotations.py` into a reusable module (e.g., `app/patterns.py`) with `detect_patterns(candles, pattern_names)` function
|
||||
- [x] 1.2 Add `GET /patterns/available` endpoint returning all supported CDL pattern names with display names
|
||||
- [x] 1.3 Add `POST /patterns/detect` endpoint accepting `{candles, patterns}`, running selected CDL functions, returning span annotations with source "talib"
|
||||
- [x] 1.4 Add input validation: reject invalid pattern names with HTTP 400, handle empty patterns list as "run all"
|
||||
|
||||
## 2. FastAPI Training Endpoints
|
||||
|
||||
- [ ] 2.1 Add `POST /training/start` endpoint that launches training in a background thread, returns `{run_id, status: "running"}`, and rejects concurrent runs with HTTP 409
|
||||
- [ ] 2.2 Add `GET /training/runs` endpoint returning training run history from the `training_runs` table, sorted by date descending
|
||||
- [ ] 2.3 Add `GET /training/dataset-info` endpoint returning labeled dataset file path, existence, size, and row count
|
||||
- [ ] 2.4 Add background training thread management: track active run, update DB status on completion/failure
|
||||
|
||||
## 3. FastAPI Model Loading Endpoint
|
||||
|
||||
- [ ] 3.1 Add `POST /model/load` endpoint accepting `{run_id}`, looking up the training run, loading the model artifact, and replacing the active model in `AppState`
|
||||
- [ ] 3.2 Add thread-safe model swap with locking to prevent conflicts with in-flight prediction requests
|
||||
|
||||
## 4. Next.js Proxy Routes
|
||||
|
||||
- [ ] 4.1 Add `GET /api/patterns/available` proxy route
|
||||
- [ ] 4.2 Add `POST /api/patterns/detect` proxy route
|
||||
- [ ] 4.3 Add `POST /api/training/start` proxy route
|
||||
- [ ] 4.4 Add `GET /api/training/runs` proxy route
|
||||
- [ ] 4.5 Add `GET /api/training/dataset-info` proxy route
|
||||
- [ ] 4.6 Add `POST /api/model/load` proxy route
|
||||
- [ ] 4.7 Extend `DELETE /api/span-annotations` to support `source` and `label` query parameters for bulk deletion
|
||||
|
||||
## 5. TA-Lib Pattern UI Panel
|
||||
|
||||
- [ ] 5.1 Create `TalibPatternPanel` component with collapsible section, fetching available patterns from `/api/patterns/available` on mount
|
||||
- [ ] 5.2 Add pattern checkboxes grouped by category with "Select All" / "Deselect All" toggle
|
||||
- [ ] 5.3 Add "Detect Patterns" button that sends selected patterns + chart candles to `/api/patterns/detect`, shows loading state
|
||||
- [ ] 5.4 Save detection results as span annotations via `POST /api/span-annotations` with `source: "talib"` and refresh the annotation list
|
||||
- [ ] 5.5 Add detection results summary showing pattern counts grouped by name
|
||||
- [ ] 5.6 Add "Clear All TA-Lib" bulk delete button calling `DELETE /api/span-annotations?source=talib&chartId=X`
|
||||
- [ ] 5.7 Add per-pattern-type delete in results summary
|
||||
|
||||
## 6. Training UI Panel
|
||||
|
||||
- [ ] 6.1 Create `TrainingPanel` component with collapsible section
|
||||
- [ ] 6.2 Add model type dropdown (Random Forest / XGBoost) defaulting to Random Forest
|
||||
- [ ] 6.3 Add dataset info display fetching from `/api/training/dataset-info`, showing warning if missing
|
||||
- [ ] 6.4 Add "Start Training" button with loading state, disabled when training active or dataset missing
|
||||
- [ ] 6.5 Add training status polling (5s interval) while a run is active, showing progress indicator
|
||||
- [ ] 6.6 Add training completion/failure handling with success message and metrics display
|
||||
- [ ] 6.7 Add training run history list fetching from `/api/training/runs`, showing 5 most recent runs with model type, status, date, metrics
|
||||
|
||||
## 7. Model Selector Integration
|
||||
|
||||
- [ ] 7.1 Create `ModelSelector` dropdown component fetching completed training runs from `/api/training/runs`
|
||||
- [ ] 7.2 Integrate `ModelSelector` into `PredictionPanel` above action buttons, showing current model as active
|
||||
- [ ] 7.3 Wire model switch: on selection call `POST /api/model/load`, clear prediction cache, refresh model info
|
||||
- [ ] 7.4 Handle model load errors: show error toast, keep previous model active
|
||||
|
||||
## 8. Sidebar Layout Integration
|
||||
|
||||
- [ ] 8.1 Add `TalibPatternPanel` to the sidebar in `page.tsx` between SpanAnnotationList and PredictionPanel
|
||||
- [ ] 8.2 Add `TrainingPanel` to the sidebar between TalibPatternPanel and PredictionPanel
|
||||
- [ ] 8.3 Make TalibPatternPanel, TrainingPanel, and PredictionPanel collapsible (default collapsed for new panels)
|
||||
- [ ] 8.4 Wire all new component state and callbacks in `page.tsx`
|
||||
|
|
@ -22,6 +22,12 @@ import mlflow.xgboost
|
|||
|
||||
from app.config import load_config, PipelineConfig
|
||||
from app.preprocessing import preprocess_candles, extract_feature_columns
|
||||
from app.patterns import (
|
||||
TALIB_PATTERNS,
|
||||
get_available_patterns,
|
||||
validate_pattern_names,
|
||||
detect_patterns,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
|
|
@ -741,6 +747,85 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern Detection Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PatternInfo(BaseModel):
|
||||
"""A single supported CDL pattern."""
|
||||
function_name: str
|
||||
display_name: str
|
||||
|
||||
|
||||
class DetectPatternsRequest(BaseModel):
|
||||
"""Request model for POST /patterns/detect."""
|
||||
candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
|
||||
patterns: List[str] = Field(
|
||||
default=[],
|
||||
description="CDL function names to run. Empty list means run all.",
|
||||
)
|
||||
|
||||
|
||||
class SpanAnnotation(BaseModel):
|
||||
"""A span annotation returned by pattern detection."""
|
||||
start_time: int
|
||||
end_time: int
|
||||
label: str
|
||||
confidence: float = Field(..., ge=0.0, le=1.0)
|
||||
source: str
|
||||
notes: str
|
||||
|
||||
|
||||
class DetectPatternsResponse(BaseModel):
|
||||
"""Response model for POST /patterns/detect."""
|
||||
annotations: List[SpanAnnotation]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@app.get("/patterns/available", response_model=List[PatternInfo])
|
||||
async def patterns_available():
|
||||
"""
|
||||
Return all supported CDL pattern names with display names.
|
||||
"""
|
||||
return get_available_patterns()
|
||||
|
||||
|
||||
@app.post("/patterns/detect", response_model=DetectPatternsResponse)
|
||||
async def patterns_detect(request: DetectPatternsRequest):
|
||||
"""
|
||||
Detect TA-Lib CDL patterns on provided candle data.
|
||||
|
||||
- Empty ``patterns`` list runs all available CDL functions.
|
||||
- Invalid pattern names return HTTP 400.
|
||||
"""
|
||||
# 1.4 – Validate pattern names
|
||||
if request.patterns:
|
||||
invalid = validate_pattern_names(request.patterns)
|
||||
if invalid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
|
||||
f"Use GET /patterns/available to see supported patterns.",
|
||||
)
|
||||
|
||||
candles_data = [c.model_dump() for c in request.candles]
|
||||
|
||||
try:
|
||||
raw_annotations = detect_patterns(candles_data, request.patterns or None)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=str(exc),
|
||||
)
|
||||
|
||||
annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
|
||||
|
||||
return DetectPatternsResponse(
|
||||
annotations=annotations,
|
||||
metadata={"source": "talib", "count": len(annotations)},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
|
|
|||
173
services/ml/app/patterns.py
Normal file
173
services/ml/app/patterns.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
TA-Lib CDL pattern detection module.
|
||||
|
||||
Extracted from generate_talib_annotations.py for use in the FastAPI service.
|
||||
Provides detect_patterns() for use in pattern detection endpoints.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import talib
|
||||
except ImportError:
|
||||
talib = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TALIB_PATTERNS: Dict[str, str] = {
|
||||
'CDLENGULFING': 'Engulfing',
|
||||
'CDLHAMMER': 'Hammer',
|
||||
'CDLINVERTEDHAMMER': 'Inverted Hammer',
|
||||
'CDLSHOOTINGSTAR': 'Shooting Star',
|
||||
'CDLDOJI': 'Doji',
|
||||
'CDLDOJISTAR': 'Doji Star',
|
||||
'CDLMORNINGSTAR': 'Morning Star',
|
||||
'CDLEVENINGSTAR': 'Evening Star',
|
||||
'CDLHARAMI': 'Harami',
|
||||
'CDLHARAMICROSS': 'Harami Cross',
|
||||
'CDLPIERCING': 'Piercing',
|
||||
'CDLDARKCLOUDCOVER': 'Dark Cloud Cover',
|
||||
'CDLMARUBOZU': 'Marubozu',
|
||||
'CDLSPINNINGTOP': 'Spinning Top',
|
||||
'CDL3WHITESOLDIERS': 'Three White Soldiers',
|
||||
'CDL3BLACKCROWS': 'Three Black Crows',
|
||||
'CDLABANDONEDBABY': 'Abandoned Baby',
|
||||
'CDLADVANCEBLOCK': 'Advance Block',
|
||||
'CDLBELTHOLD': 'Belt Hold',
|
||||
'CDLBREAKAWAY': 'Breakaway',
|
||||
'CDLCLOSINGMARUBOZU': 'Closing Marubozu',
|
||||
'CDLCONCEALBABYSWALL': 'Concealing Baby Swallow',
|
||||
'CDLCOUNTERATTACK': 'Counterattack',
|
||||
'CDLDRAGONFLYDOJI': 'Dragonfly Doji',
|
||||
'CDLGAPSIDESIDEWHITE': 'Up/Down Gap Side-by-Side White Lines',
|
||||
'CDLGRAVESTONEDOJI': 'Gravestone Doji',
|
||||
'CDLHANGINGMAN': 'Hanging Man',
|
||||
'CDLHIGHWAVE': 'High Wave',
|
||||
'CDLHIKKAKE': 'Hikkake',
|
||||
'CDLHIKKAKEMOD': 'Modified Hikkake',
|
||||
'CDLHOMINGPIGEON': 'Homing Pigeon',
|
||||
'CDLIDENTICAL3CROWS': 'Identical Three Crows',
|
||||
'CDLINNECK': 'In-Neck',
|
||||
'CDLKICKING': 'Kicking',
|
||||
'CDLKICKINGBYLENGTH': 'Kicking by Length',
|
||||
'CDLLADDERBOTTOM': 'Ladder Bottom',
|
||||
'CDLLONGLEGGEDDOJI': 'Long-Legged Doji',
|
||||
'CDLLONGLINE': 'Long Line',
|
||||
'CDLMATCHINGLOW': 'Matching Low',
|
||||
'CDLMATHOLD': 'Mat Hold',
|
||||
'CDLMORNINGDOJISTAR': 'Morning Doji Star',
|
||||
'CDLONNECK': 'On-Neck',
|
||||
'CDLRISEFALL3METHODS': 'Rising/Falling Three Methods',
|
||||
'CDLSEPARATINGLINES': 'Separating Lines',
|
||||
'CDLSHORTLINE': 'Short Line',
|
||||
'CDLSTALLEDPATTERN': 'Stalled Pattern',
|
||||
'CDLSTICKSANDWICH': 'Stick Sandwich',
|
||||
'CDLTAKURI': 'Takuri',
|
||||
'CDLTASUKIGAP': 'Tasuki Gap',
|
||||
'CDLTHRUSTING': 'Thrusting',
|
||||
'CDLTRISTAR': 'Tristar',
|
||||
'CDLUNIQUE3RIVER': 'Unique Three River',
|
||||
'CDLUPSIDEGAP2CROWS': 'Upside Gap Two Crows',
|
||||
'CDLXSIDEGAP3METHODS': 'Upside/Downside Gap Three Methods',
|
||||
}
|
||||
|
||||
|
||||
def get_available_patterns() -> List[Dict[str, str]]:
|
||||
"""
|
||||
Return all supported CDL pattern names with display names.
|
||||
|
||||
Returns:
|
||||
List of dicts with function_name and display_name keys.
|
||||
"""
|
||||
return [
|
||||
{"function_name": fn, "display_name": display}
|
||||
for fn, display in TALIB_PATTERNS.items()
|
||||
]
|
||||
|
||||
|
||||
def validate_pattern_names(pattern_names: List[str]) -> List[str]:
|
||||
"""
|
||||
Validate a list of pattern names against the supported set.
|
||||
|
||||
Args:
|
||||
pattern_names: List of CDL function names to validate.
|
||||
|
||||
Returns:
|
||||
List of invalid pattern names (empty if all valid).
|
||||
"""
|
||||
return [name for name in pattern_names if name not in TALIB_PATTERNS]
|
||||
|
||||
|
||||
def detect_patterns(
|
||||
candles: List[Dict[str, Any]],
|
||||
pattern_names: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run TA-Lib CDL pattern functions on candle data and return span annotations.
|
||||
|
||||
Args:
|
||||
candles: List of candle dicts with keys time, open, high, low, close.
|
||||
pattern_names: CDL function names to run. None or empty list means run all.
|
||||
|
||||
Returns:
|
||||
List of span annotation dicts with keys:
|
||||
start_time, end_time, label, confidence, source, notes.
|
||||
"""
|
||||
if talib is None:
|
||||
raise RuntimeError("TA-Lib is not installed")
|
||||
|
||||
if not candles:
|
||||
return []
|
||||
|
||||
if not pattern_names:
|
||||
pattern_names = list(TALIB_PATTERNS.keys())
|
||||
|
||||
times = np.array([c["time"] for c in candles], dtype=np.float64)
|
||||
open_prices = np.array([c["open"] for c in candles], dtype=np.float64)
|
||||
high_prices = np.array([c["high"] for c in candles], dtype=np.float64)
|
||||
low_prices = np.array([c["low"] for c in candles], dtype=np.float64)
|
||||
close_prices = np.array([c["close"] for c in candles], dtype=np.float64)
|
||||
|
||||
n = len(candles)
|
||||
annotations = []
|
||||
|
||||
for pattern_func in pattern_names:
|
||||
if not hasattr(talib, pattern_func):
|
||||
logger.warning("Unknown TA-Lib function: %s", pattern_func)
|
||||
continue
|
||||
|
||||
try:
|
||||
func = getattr(talib, pattern_func)
|
||||
values = func(open_prices, high_prices, low_prices, close_prices)
|
||||
except Exception as exc:
|
||||
logger.error("Error running %s: %s", pattern_func, exc)
|
||||
continue
|
||||
|
||||
friendly_name = TALIB_PATTERNS.get(pattern_func, pattern_func)
|
||||
|
||||
for idx, value in enumerate(values):
|
||||
if value == 0:
|
||||
continue
|
||||
|
||||
if value > 0:
|
||||
label = f"Bullish {friendly_name}"
|
||||
else:
|
||||
label = f"Bearish {friendly_name}"
|
||||
|
||||
start_idx = max(0, idx - 1)
|
||||
end_idx = min(n - 1, idx + 1)
|
||||
|
||||
annotations.append({
|
||||
"start_time": int(times[start_idx]),
|
||||
"end_time": int(times[end_idx]),
|
||||
"label": label,
|
||||
"confidence": abs(int(value)) / 100.0,
|
||||
"source": "talib",
|
||||
"notes": f"TA-Lib {pattern_func} detection",
|
||||
})
|
||||
|
||||
logger.info("Detected %d pattern annotations", len(annotations))
|
||||
return annotations
|
||||
Loading…
Add table
Add a link
Reference in a new issue