diff --git a/openspec/changes/ml-ui-connection/tasks.md b/openspec/changes/ml-ui-connection/tasks.md new file mode 100644 index 0000000..00c1cd4 --- /dev/null +++ b/openspec/changes/ml-ui-connection/tasks.md @@ -0,0 +1,62 @@ +## 1. FastAPI Pattern Detection Endpoints + +- [x] 1.1 Extract pattern detection logic from `generate_talib_annotations.py` into a reusable module (e.g., `app/patterns.py`) with `detect_patterns(candles, pattern_names)` function +- [x] 1.2 Add `GET /patterns/available` endpoint returning all supported CDL pattern names with display names +- [x] 1.3 Add `POST /patterns/detect` endpoint accepting `{candles, patterns}`, running selected CDL functions, returning span annotations with source "talib" +- [x] 1.4 Add input validation: reject invalid pattern names with HTTP 400, handle empty patterns list as "run all" + +## 2. FastAPI Training Endpoints + +- [ ] 2.1 Add `POST /training/start` endpoint that launches training in a background thread, returns `{run_id, status: "running"}`, and rejects concurrent runs with HTTP 409 +- [ ] 2.2 Add `GET /training/runs` endpoint returning training run history from the `training_runs` table, sorted by date descending +- [ ] 2.3 Add `GET /training/dataset-info` endpoint returning labeled dataset file path, existence, size, and row count +- [ ] 2.4 Add background training thread management: track active run, update DB status on completion/failure + +## 3. FastAPI Model Loading Endpoint + +- [ ] 3.1 Add `POST /model/load` endpoint accepting `{run_id}`, looking up the training run, loading the model artifact, and replacing the active model in `AppState` +- [ ] 3.2 Add thread-safe model swap with locking to prevent conflicts with in-flight prediction requests + +## 4. 
Next.js Proxy Routes + +- [ ] 4.1 Add `GET /api/patterns/available` proxy route +- [ ] 4.2 Add `POST /api/patterns/detect` proxy route +- [ ] 4.3 Add `POST /api/training/start` proxy route +- [ ] 4.4 Add `GET /api/training/runs` proxy route +- [ ] 4.5 Add `GET /api/training/dataset-info` proxy route +- [ ] 4.6 Add `POST /api/model/load` proxy route +- [ ] 4.7 Extend `DELETE /api/span-annotations` to support `source` and `label` query parameters for bulk deletion + +## 5. TA-Lib Pattern UI Panel + +- [ ] 5.1 Create `TalibPatternPanel` component with collapsible section, fetching available patterns from `/api/patterns/available` on mount +- [ ] 5.2 Add pattern checkboxes grouped by category with "Select All" / "Deselect All" toggle +- [ ] 5.3 Add "Detect Patterns" button that sends selected patterns + chart candles to `/api/patterns/detect`, shows loading state +- [ ] 5.4 Save detection results as span annotations via `POST /api/span-annotations` with `source: "talib"` and refresh the annotation list +- [ ] 5.5 Add detection results summary showing pattern counts grouped by name +- [ ] 5.6 Add "Clear All TA-Lib" bulk delete button calling `DELETE /api/span-annotations?source=talib&chartId=X` +- [ ] 5.7 Add per-pattern-type delete in results summary + +## 6. 
Training UI Panel + +- [ ] 6.1 Create `TrainingPanel` component with collapsible section +- [ ] 6.2 Add model type dropdown (Random Forest / XGBoost) defaulting to Random Forest +- [ ] 6.3 Add dataset info display fetching from `/api/training/dataset-info`, showing warning if missing +- [ ] 6.4 Add "Start Training" button with loading state, disabled when training active or dataset missing +- [ ] 6.5 Add training status polling (5s interval) while a run is active, showing progress indicator +- [ ] 6.6 Add training completion/failure handling with success message and metrics display +- [ ] 6.7 Add training run history list fetching from `/api/training/runs`, showing 5 most recent runs with model type, status, date, metrics + +## 7. Model Selector Integration + +- [ ] 7.1 Create `ModelSelector` dropdown component fetching completed training runs from `/api/training/runs` +- [ ] 7.2 Integrate `ModelSelector` into `PredictionPanel` above action buttons, showing current model as active +- [ ] 7.3 Wire model switch: on selection call `POST /api/model/load`, clear prediction cache, refresh model info +- [ ] 7.4 Handle model load errors: show error toast, keep previous model active + +## 8. 
Sidebar Layout Integration + +- [ ] 8.1 Add `TalibPatternPanel` to the sidebar in `page.tsx` between SpanAnnotationList and PredictionPanel +- [ ] 8.2 Add `TrainingPanel` to the sidebar between TalibPatternPanel and PredictionPanel +- [ ] 8.3 Make TalibPatternPanel, TrainingPanel, and PredictionPanel collapsible (default collapsed for new panels) +- [ ] 8.4 Wire all new component state and callbacks in `page.tsx` diff --git a/services/ml/app/main.py b/services/ml/app/main.py index c6119b9..c769447 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -22,6 +22,12 @@ import mlflow.xgboost from app.config import load_config, PipelineConfig from app.preprocessing import preprocess_candles, extract_feature_columns +from app.patterns import ( + TALIB_PATTERNS, + get_available_patterns, + validate_pattern_names, + detect_patterns, +) # Configure logging logging.basicConfig( @@ -741,6 +747,85 @@ async def predict_batch(request: BatchPredictRequest): ) +# --------------------------------------------------------------------------- +# Pattern Detection Endpoints +# --------------------------------------------------------------------------- + +class PatternInfo(BaseModel): + """A single supported CDL pattern.""" + function_name: str + display_name: str + + +class DetectPatternsRequest(BaseModel): + """Request model for POST /patterns/detect.""" + candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data") + patterns: List[str] = Field( + default=[], + description="CDL function names to run. 
class DetectPatternsRequest(BaseModel):
    """Request model for POST /patterns/detect."""

    # At least one candle is required (min_length=1).
    candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
    # Optional subset of CDL function names; the endpoint treats an empty
    # list as "run every supported pattern".
    patterns: List[str] = Field(
        default=[],
        description="CDL function names to run. Empty list means run all.",
    )


class SpanAnnotation(BaseModel):
    """A span annotation returned by pattern detection."""

    # Time values of the first and last candle covered by the span.
    start_time: int
    end_time: int
    label: str
    # Validated to lie in the closed interval [0.0, 1.0].
    confidence: float = Field(..., ge=0.0, le=1.0)
    source: str
    notes: str


class DetectPatternsResponse(BaseModel):
    """Response model for POST /patterns/detect."""

    annotations: List[SpanAnnotation]
    # Free-form summary; the endpoint fills in "source" and "count".
    metadata: Dict[str, Any]


@app.get("/patterns/available", response_model=List[PatternInfo])
async def patterns_available():
    """
    Return all supported CDL pattern names with display names.

    Delegates to ``get_available_patterns()``; the result is serialised
    through the ``PatternInfo`` response model.
    """
    return get_available_patterns()


@app.post("/patterns/detect", response_model=DetectPatternsResponse)
async def patterns_detect(request: DetectPatternsRequest):
    """
    Detect TA-Lib CDL patterns on provided candle data.

    - Empty ``patterns`` list runs all available CDL functions.
    - Invalid pattern names return HTTP 400.
    - A ``RuntimeError`` from the detection module is reported as HTTP 503.

    Returns:
        DetectPatternsResponse with one SpanAnnotation per detection and a
        metadata dict holding the source name and the annotation count.
    """
    # 1.4 – Validate pattern names (skipped for the "run all" case)
    if request.patterns:
        invalid = validate_pattern_names(request.patterns)
        if invalid:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
                       f"Use GET /patterns/available to see supported patterns.",
            )

    # The detection module works on plain dicts, not pydantic models.
    candles_data = [c.model_dump() for c in request.candles]

    # NOTE(review): detect_patterns is called synchronously inside this
    # coroutine; if it is slow for large inputs it will hold the event loop.
    # Consider offloading to a worker thread - confirm with profiling.
    try:
        raw_annotations = detect_patterns(candles_data, request.patterns or None)
    except RuntimeError as exc:
        # Chain the original error so logs/tracebacks keep the root cause.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=str(exc),
        ) from exc

    annotations = [SpanAnnotation(**ann) for ann in raw_annotations]

    return DetectPatternsResponse(
        annotations=annotations,
        metadata={"source": "talib", "count": len(annotations)},
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
"""
TA-Lib CDL pattern detection module.

Extracted from generate_talib_annotations.py for use in the FastAPI service.
Provides detect_patterns() for use in pattern detection endpoints.
"""

from typing import List, Dict, Any, Optional
import logging

import numpy as np

try:
    import talib
except ImportError:
    # TA-Lib is optional at import time; detect_patterns() raises if it is
    # actually needed but missing.
    talib = None

logger = logging.getLogger(__name__)

# Supported CDL function name -> human-readable display name.
TALIB_PATTERNS: Dict[str, str] = {
    'CDLENGULFING': 'Engulfing',
    'CDLHAMMER': 'Hammer',
    'CDLINVERTEDHAMMER': 'Inverted Hammer',
    'CDLSHOOTINGSTAR': 'Shooting Star',
    'CDLDOJI': 'Doji',
    'CDLDOJISTAR': 'Doji Star',
    'CDLMORNINGSTAR': 'Morning Star',
    'CDLEVENINGSTAR': 'Evening Star',
    'CDLHARAMI': 'Harami',
    'CDLHARAMICROSS': 'Harami Cross',
    'CDLPIERCING': 'Piercing',
    'CDLDARKCLOUDCOVER': 'Dark Cloud Cover',
    'CDLMARUBOZU': 'Marubozu',
    'CDLSPINNINGTOP': 'Spinning Top',
    'CDL3WHITESOLDIERS': 'Three White Soldiers',
    'CDL3BLACKCROWS': 'Three Black Crows',
    'CDLABANDONEDBABY': 'Abandoned Baby',
    'CDLADVANCEBLOCK': 'Advance Block',
    'CDLBELTHOLD': 'Belt Hold',
    'CDLBREAKAWAY': 'Breakaway',
    'CDLCLOSINGMARUBOZU': 'Closing Marubozu',
    'CDLCONCEALBABYSWALL': 'Concealing Baby Swallow',
    'CDLCOUNTERATTACK': 'Counterattack',
    'CDLDRAGONFLYDOJI': 'Dragonfly Doji',
    'CDLGAPSIDESIDEWHITE': 'Up/Down Gap Side-by-Side White Lines',
    'CDLGRAVESTONEDOJI': 'Gravestone Doji',
    'CDLHANGINGMAN': 'Hanging Man',
    'CDLHIGHWAVE': 'High Wave',
    'CDLHIKKAKE': 'Hikkake',
    'CDLHIKKAKEMOD': 'Modified Hikkake',
    'CDLHOMINGPIGEON': 'Homing Pigeon',
    'CDLIDENTICAL3CROWS': 'Identical Three Crows',
    'CDLINNECK': 'In-Neck',
    'CDLKICKING': 'Kicking',
    'CDLKICKINGBYLENGTH': 'Kicking by Length',
    'CDLLADDERBOTTOM': 'Ladder Bottom',
    'CDLLONGLEGGEDDOJI': 'Long-Legged Doji',
    'CDLLONGLINE': 'Long Line',
    'CDLMATCHINGLOW': 'Matching Low',
    'CDLMATHOLD': 'Mat Hold',
    'CDLMORNINGDOJISTAR': 'Morning Doji Star',
    'CDLONNECK': 'On-Neck',
    'CDLRISEFALL3METHODS': 'Rising/Falling Three Methods',
    'CDLSEPARATINGLINES': 'Separating Lines',
    'CDLSHORTLINE': 'Short Line',
    'CDLSTALLEDPATTERN': 'Stalled Pattern',
    'CDLSTICKSANDWICH': 'Stick Sandwich',
    'CDLTAKURI': 'Takuri',
    'CDLTASUKIGAP': 'Tasuki Gap',
    'CDLTHRUSTING': 'Thrusting',
    'CDLTRISTAR': 'Tristar',
    'CDLUNIQUE3RIVER': 'Unique Three River',
    'CDLUPSIDEGAP2CROWS': 'Upside Gap Two Crows',
    'CDLXSIDEGAP3METHODS': 'Upside/Downside Gap Three Methods',
}


def get_available_patterns() -> List[Dict[str, str]]:
    """
    Return all supported CDL pattern names with display names.

    Returns:
        List of dicts with function_name and display_name keys, in the
        order the patterns are declared in TALIB_PATTERNS.
    """
    return [
        {"function_name": fn, "display_name": display}
        for fn, display in TALIB_PATTERNS.items()
    ]


def validate_pattern_names(pattern_names: List[str]) -> List[str]:
    """
    Validate a list of pattern names against the supported set.

    Matching is exact (case-sensitive) against the keys of TALIB_PATTERNS.

    Args:
        pattern_names: List of CDL function names to validate.

    Returns:
        List of invalid pattern names in input order (empty if all valid).
    """
    return [name for name in pattern_names if name not in TALIB_PATTERNS]


def _signal_annotations(
    values: Any,
    times: np.ndarray,
    pattern_func: str,
) -> List[Dict[str, Any]]:
    """Convert one CDL function's output array into span annotation dicts."""
    friendly_name = TALIB_PATTERNS.get(pattern_func, pattern_func)
    last_idx = len(times) - 1
    spans = []

    # Only non-zero entries are detections; flatnonzero skips the rest
    # without a Python-level check per candle.
    for idx in np.flatnonzero(values).tolist():
        signal = int(values[idx])
        direction = "Bullish" if signal > 0 else "Bearish"

        # The span covers the signal candle plus one neighbour on each side,
        # clipped to the bounds of the input.
        start_idx = max(0, idx - 1)
        end_idx = min(last_idx, idx + 1)

        spans.append({
            "start_time": int(times[start_idx]),
            "end_time": int(times[end_idx]),
            "label": f"{direction} {friendly_name}",
            # Most CDL functions emit +/-100, but some (e.g. the Hikkake
            # family on confirmation) emit +/-200. Clamp so the value always
            # stays within the [0, 1] range consumers validate against.
            "confidence": min(abs(signal) / 100.0, 1.0),
            "source": "talib",
            "notes": f"TA-Lib {pattern_func} detection",
        })

    return spans


def detect_patterns(
    candles: List[Dict[str, Any]],
    pattern_names: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Run TA-Lib CDL pattern functions on candle data and return span annotations.

    Args:
        candles: List of candle dicts with keys time, open, high, low, close.
        pattern_names: CDL function names to run. None or empty list means
            run all. Repeated names are run once.

    Returns:
        List of span annotation dicts with keys:
        start_time, end_time, label, confidence, source, notes.
        confidence is always within [0.0, 1.0].

    Raises:
        RuntimeError: If the talib package could not be imported.
    """
    if talib is None:
        raise RuntimeError("TA-Lib is not installed")

    if not candles:
        return []

    if pattern_names:
        # Drop repeats (keeping request order) so a pattern listed twice
        # does not yield duplicate annotations.
        selected = list(dict.fromkeys(pattern_names))
    else:
        selected = list(TALIB_PATTERNS)

    # Timestamps round-trip through float64 and are cast back with int();
    # this is exact for integer values up to 2**53.
    times = np.array([c["time"] for c in candles], dtype=np.float64)
    open_prices, high_prices, low_prices, close_prices = (
        np.array([c[key] for c in candles], dtype=np.float64)
        for key in ("open", "high", "low", "close")
    )

    annotations: List[Dict[str, Any]] = []

    for pattern_func in selected:
        func = getattr(talib, pattern_func, None)
        if func is None:
            logger.warning("Unknown TA-Lib function: %s", pattern_func)
            continue

        try:
            values = func(open_prices, high_prices, low_prices, close_prices)
        except Exception as exc:
            # Best effort: one failing pattern must not abort the others.
            logger.error("Error running %s: %s", pattern_func, exc)
            continue

        annotations.extend(_signal_annotations(values, times, pattern_func))

    logger.info("Detected %d pattern annotations", len(annotations))
    return annotations