feat: add FastAPI pattern detection endpoints (Section 1)
- Extract CDL pattern detection logic into services/ml/app/patterns.py
with TALIB_PATTERNS dict, get_available_patterns(), validate_pattern_names(),
and detect_patterns(candles, pattern_names) functions
- Add GET /patterns/available endpoint returning all 54 supported CDL pattern
names with display names
- Add POST /patterns/detect endpoint accepting {candles, patterns}, running
selected CDL functions, returning span annotations with source "talib"
- Add input validation: reject invalid pattern names with HTTP 400,
treat empty patterns list as "run all"
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
38df874255
commit
b8e649e333
3 changed files with 320 additions and 0 deletions
|
|
@ -22,6 +22,12 @@ import mlflow.xgboost
|
|||
|
||||
from app.config import load_config, PipelineConfig
|
||||
from app.preprocessing import preprocess_candles, extract_feature_columns
|
||||
from app.patterns import (
|
||||
TALIB_PATTERNS,
|
||||
get_available_patterns,
|
||||
validate_pattern_names,
|
||||
detect_patterns,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
|
|
@ -741,6 +747,85 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern Detection Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PatternInfo(BaseModel):
|
||||
"""A single supported CDL pattern."""
|
||||
function_name: str
|
||||
display_name: str
|
||||
|
||||
|
||||
class DetectPatternsRequest(BaseModel):
|
||||
"""Request model for POST /patterns/detect."""
|
||||
candles: List[CandleData] = Field(..., min_length=1, description="Array of candle data")
|
||||
patterns: List[str] = Field(
|
||||
default=[],
|
||||
description="CDL function names to run. Empty list means run all.",
|
||||
)
|
||||
|
||||
|
||||
class SpanAnnotation(BaseModel):
|
||||
"""A span annotation returned by pattern detection."""
|
||||
start_time: int
|
||||
end_time: int
|
||||
label: str
|
||||
confidence: float = Field(..., ge=0.0, le=1.0)
|
||||
source: str
|
||||
notes: str
|
||||
|
||||
|
||||
class DetectPatternsResponse(BaseModel):
|
||||
"""Response model for POST /patterns/detect."""
|
||||
annotations: List[SpanAnnotation]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@app.get("/patterns/available", response_model=List[PatternInfo])
|
||||
async def patterns_available():
|
||||
"""
|
||||
Return all supported CDL pattern names with display names.
|
||||
"""
|
||||
return get_available_patterns()
|
||||
|
||||
|
||||
@app.post("/patterns/detect", response_model=DetectPatternsResponse)
|
||||
async def patterns_detect(request: DetectPatternsRequest):
|
||||
"""
|
||||
Detect TA-Lib CDL patterns on provided candle data.
|
||||
|
||||
- Empty ``patterns`` list runs all available CDL functions.
|
||||
- Invalid pattern names return HTTP 400.
|
||||
"""
|
||||
# 1.4 – Validate pattern names
|
||||
if request.patterns:
|
||||
invalid = validate_pattern_names(request.patterns)
|
||||
if invalid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid pattern name(s): {', '.join(invalid)}. "
|
||||
f"Use GET /patterns/available to see supported patterns.",
|
||||
)
|
||||
|
||||
candles_data = [c.model_dump() for c in request.candles]
|
||||
|
||||
try:
|
||||
raw_annotations = detect_patterns(candles_data, request.patterns or None)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=str(exc),
|
||||
)
|
||||
|
||||
annotations = [SpanAnnotation(**ann) for ann in raw_annotations]
|
||||
|
||||
return DetectPatternsResponse(
|
||||
annotations=annotations,
|
||||
metadata={"source": "talib", "count": len(annotations)},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue