diff --git a/services/ml/app/annotation_ingestion.py b/services/ml/app/annotation_ingestion.py index 38be4a8..548dc48 100644 --- a/services/ml/app/annotation_ingestion.py +++ b/services/ml/app/annotation_ingestion.py @@ -99,7 +99,7 @@ class AnnotationIngestion: def load_annotations_from_db( self, chart_name: str, - source: str = "human" + source: Optional[str] = "human" ) -> List[Dict[str, Any]]: """ Load annotations directly from PostgreSQL database. @@ -108,7 +108,8 @@ class AnnotationIngestion: Args: chart_name: Name of the chart to load annotations for - source: Filter by annotation source ('human', 'model', 'hybrid') + source: Optional source filter (e.g. 'human', 'talib', 'model'). + When None, includes all sources. Returns: List of annotation dictionaries compatible with existing processing @@ -545,7 +546,7 @@ class AnnotationIngestion: self, enriched_df: pd.DataFrame, chart_name: str, - source: str = "human" + source: Optional[str] = "human" ) -> pd.DataFrame: """ Main processing pipeline using direct database access. @@ -555,7 +556,8 @@ class AnnotationIngestion: Args: enriched_df: DataFrame with engineered features chart_name: Name of the chart to load annotations for - source: Filter by annotation source ('human', 'model', 'hybrid') + source: Optional source filter (e.g. 'human', 'talib', 'model'). + When None, includes all sources. Returns: Labeled DataFrame ready for training diff --git a/services/ml/app/data_access.py b/services/ml/app/data_access.py index be2b363..b827160 100644 --- a/services/ml/app/data_access.py +++ b/services/ml/app/data_access.py @@ -59,6 +59,25 @@ class DataAccess: if result: return dict(result._mapping) return None + + def get_chart_by_id(self, chart_id: int) -> Optional[Dict[str, Any]]: + """ + Get chart by ID. + + Args: + chart_id: Chart ID + + Returns: + Chart dictionary or None if not found + """ + with get_db() as db: + chart_id = int(chart_id) + stmt = select(self.charts).where(self.charts.c.id == chart_id) + result = db.execute(stmt).fetchone() + + if result: + return dict(result._mapping) + return None def get_candles( self, diff --git a/services/ml/app/main.py b/services/ml/app/main.py index 3632072..d428520 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -1039,6 +1039,11 @@ class TrainingStartRequest(BaseModel): "random_forest", description="Model type: random_forest or xgboost", ) + chart_id: Optional[int] = Field( + default=None, + ge=1, + description="Chart ID to train on. If omitted, falls back to first chart.", + ) class TrainingStartResponse(BaseModel): @@ -1072,7 +1077,7 @@ class DatasetInfoResponse(BaseModel): row_count: Optional[int] = None -def build_dataset_from_db(config: PipelineConfig) -> dict: +def build_dataset_from_db(config: PipelineConfig, chart_id: Optional[int] = None) -> dict: """ Build the labeled training dataset directly from the database. @@ -1090,18 +1095,25 @@ def build_dataset_from_db(config: PipelineConfig) -> dict: data_access = DataAccess() - # Find all charts, use the first one (single-chart app) - charts_df = data_access.get_all_charts() - if charts_df.empty: - raise ValueError("No charts found in database. Upload candle data first.") + # Resolve target chart + if chart_id is not None: + chart = data_access.get_chart_by_id(chart_id) + if not chart: + raise ValueError(f"Chart not found: id={chart_id}") + chart_name = chart["name"] + chart_id_int = int(chart["id"]) + else: + charts_df = data_access.get_all_charts() + if charts_df.empty: + raise ValueError("No charts found in database. Upload candle data first.") + chart = charts_df.iloc[0] + chart_name = chart["name"] + chart_id_int = int(chart["id"]) - chart = charts_df.iloc[0] - chart_name = chart["name"] - chart_id = int(chart["id"]) - logger.info(f"Building dataset for chart: {chart_name} (id={chart_id})") + logger.info(f"Building dataset for chart: {chart_name} (id={chart_id_int})") # Step 1: Export candles to raw CSV - candles_df = data_access.get_candles(chart_id) + candles_df = data_access.get_candles(chart_id_int) if candles_df.empty: raise ValueError(f"No candles found for chart: {chart_name}") @@ -1121,7 +1133,9 @@ def build_dataset_from_db(config: PipelineConfig) -> dict: # Step 3: Run annotation ingestion from database enriched_df = pd.read_csv(enriched_path, parse_dates=["time"]) ingestion = AnnotationIngestion(config.stages.annotation_ingestion) - labeled_df = ingestion.process_from_db(enriched_df, chart_name, source="human") + # Include all annotation sources so TA-Lib generated spans (source='talib') + # can be used for training alongside manual labels. + labeled_df = ingestion.process_from_db(enriched_df, chart_name, source=None) if labeled_df.empty: raise ValueError( @@ -1145,7 +1159,12 @@ def build_dataset_from_db(config: PipelineConfig) -> dict: return result -def _run_training_background(run_id: str, model_type: str, config: PipelineConfig) -> None: +def _run_training_background( + run_id: str, + model_type: str, + config: PipelineConfig, + chart_id: Optional[int] = None, +) -> None: """ Background thread target: build dataset then train a model. @@ -1160,7 +1179,7 @@ def _run_training_background(run_id: str, model_type: str, config: PipelineConfi # Build dataset from database (feature engineering + annotation ingestion) logger.info("Building dataset from database...") - build_dataset_from_db(config) + build_dataset_from_db(config, chart_id=chart_id) labeled_path = Path(config.data.labeled_path) if not labeled_path.exists(): @@ -1336,6 +1355,15 @@ async def training_start(request: TrainingStartRequest): detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}", ) + if request.chart_id is not None: + from app.data_access import DataAccess + chart = DataAccess().get_chart_by_id(request.chart_id) + if not chart: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Chart not found: id={request.chart_id}", + ) + # Reject concurrent runs (atomic check-and-set) with state.training_lock: if state.active_training_run_id is not None: @@ -1385,7 +1413,7 @@ async def training_start(request: TrainingStartRequest): # Launch background thread (daemon so it doesn't block process exit) thread = threading.Thread( target=_run_training_background, - args=(run_id, request.model_type, config), + args=(run_id, request.model_type, config, request.chart_id), daemon=True, name=f"training-{run_id[:8]}", ) diff --git a/src/app/api/training/start/route.ts b/src/app/api/training/start/route.ts index 5e73b65..699bf20 100644 --- a/src/app/api/training/start/route.ts +++ b/src/app/api/training/start/route.ts @@ -6,6 +6,7 @@ const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '100 const TrainingStartRequestSchema = z.object({ model_type: z.string().min(1, 'model_type must be a non-empty string'), + chart_id: z.number().int().positive().optional(), }); export async function POST(request: NextRequest) { diff --git a/src/app/page.tsx b/src/app/page.tsx index 0372ca8..2ce8d00 100644 --- a/src/app/page.tsx +++ b/src/app/page.tsx @@ -856,7 +856,7 @@ export default function Home() { {/* Training Panel */}
- +
{/* Predictions */} diff --git a/src/components/TrainingPanel.tsx b/src/components/TrainingPanel.tsx index fb7220b..4101f16 100644 --- a/src/components/TrainingPanel.tsx +++ b/src/components/TrainingPanel.tsx @@ -27,6 +27,10 @@ interface TrainingRun { error?: string; } +interface TrainingPanelProps { + activeChartId: number | null; +} + const MODEL_TYPES = [ { value: 'random_forest', label: 'Random Forest' }, { value: 'xgboost', label: 'XGBoost' }, @@ -58,7 +62,7 @@ function StatusBadge({ status }: { status: string }) { ); } -export default function TrainingPanel() { +export default function TrainingPanel({ activeChartId }: TrainingPanelProps) { const [expanded, setExpanded] = useState(false); const [modelType, setModelType] = useState('random_forest'); const [datasetInfo, setDatasetInfo] = useState(null); @@ -172,13 +176,21 @@ export default function TrainingPanel() { }, [isTraining, activeRunId, fetchRuns, fetchActiveRun]); const handleStartTraining = async () => { + if (!activeChartId) { + setStatusMessage({ + type: 'error', + text: 'Select a chart before starting training', + }); + return; + } + setStatusMessage(null); setIsTraining(true); try { const res = await fetch('/api/training/start', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ model_type: modelType }), + body: JSON.stringify({ model_type: modelType, chart_id: activeChartId }), }); if (res.status === 409) { @@ -227,7 +239,7 @@ export default function TrainingPanel() { } }; - const canTrain = !isTraining; + const canTrain = !isTraining && !!activeChartId; return (
@@ -308,6 +320,11 @@ export default function TrainingPanel() { 'Start Training' )} + {!activeChartId && ( +

+ Select a chart to enable training. +

+ )} {/* Status message */} {statusMessage && (