fix(training): use selected chart and include TA-Lib span sources

This commit is contained in:
Marko Djordjevic 2026-02-18 23:21:23 +01:00
parent 3448c6febd
commit 07064fbf40
6 changed files with 89 additions and 22 deletions

View file

@ -99,7 +99,7 @@ class AnnotationIngestion:
def load_annotations_from_db(
self,
chart_name: str,
source: str = "human"
source: Optional[str] = "human"
) -> List[Dict[str, Any]]:
"""
Load annotations directly from PostgreSQL database.
@ -108,7 +108,8 @@ class AnnotationIngestion:
Args:
chart_name: Name of the chart to load annotations for
source: Filter by annotation source ('human', 'model', 'hybrid')
source: Optional source filter (e.g. 'human', 'talib', 'model').
When None, includes all sources.
Returns:
List of annotation dictionaries compatible with existing processing
@ -545,7 +546,7 @@ class AnnotationIngestion:
self,
enriched_df: pd.DataFrame,
chart_name: str,
source: str = "human"
source: Optional[str] = "human"
) -> pd.DataFrame:
"""
Main processing pipeline using direct database access.
@ -555,7 +556,8 @@ class AnnotationIngestion:
Args:
enriched_df: DataFrame with engineered features
chart_name: Name of the chart to load annotations for
source: Filter by annotation source ('human', 'model', 'hybrid')
source: Optional source filter (e.g. 'human', 'talib', 'model').
When None, includes all sources.
Returns:
Labeled DataFrame ready for training

View file

@ -59,6 +59,25 @@ class DataAccess:
if result:
return dict(result._mapping)
return None
def get_chart_by_id(self, chart_id: int) -> Optional[Dict[str, Any]]:
    """
    Fetch a single chart row whose ``id`` column equals the given ID.

    Args:
        chart_id: Chart ID; coerced with ``int()`` before it is used
            in the query.

    Returns:
        The matching row as a plain dictionary, or None when the
        query yields no row.
    """
    with get_db() as session:
        wanted_id = int(chart_id)
        query = select(self.charts).where(self.charts.c.id == wanted_id)
        row = session.execute(query).fetchone()
        # Convert while the session is still open, as the lookup above does.
        return dict(row._mapping) if row else None
def get_candles(
self,

View file

@ -1039,6 +1039,11 @@ class TrainingStartRequest(BaseModel):
"random_forest",
description="Model type: random_forest or xgboost",
)
chart_id: Optional[int] = Field(
default=None,
ge=1,
description="Chart ID to train on. If omitted, falls back to first chart.",
)
class TrainingStartResponse(BaseModel):
@ -1072,7 +1077,7 @@ class DatasetInfoResponse(BaseModel):
row_count: Optional[int] = None
def build_dataset_from_db(config: PipelineConfig) -> dict:
def build_dataset_from_db(config: PipelineConfig, chart_id: Optional[int] = None) -> dict:
"""
Build the labeled training dataset directly from the database.
@ -1090,18 +1095,25 @@ def build_dataset_from_db(config: PipelineConfig) -> dict:
data_access = DataAccess()
# Find all charts, use the first one (single-chart app)
charts_df = data_access.get_all_charts()
if charts_df.empty:
raise ValueError("No charts found in database. Upload candle data first.")
# Resolve target chart
if chart_id is not None:
chart = data_access.get_chart_by_id(chart_id)
if not chart:
raise ValueError(f"Chart not found: id={chart_id}")
chart_name = chart["name"]
chart_id_int = int(chart["id"])
else:
charts_df = data_access.get_all_charts()
if charts_df.empty:
raise ValueError("No charts found in database. Upload candle data first.")
chart = charts_df.iloc[0]
chart_name = chart["name"]
chart_id_int = int(chart["id"])
chart = charts_df.iloc[0]
chart_name = chart["name"]
chart_id = int(chart["id"])
logger.info(f"Building dataset for chart: {chart_name} (id={chart_id})")
logger.info(f"Building dataset for chart: {chart_name} (id={chart_id_int})")
# Step 1: Export candles to raw CSV
candles_df = data_access.get_candles(chart_id)
candles_df = data_access.get_candles(chart_id_int)
if candles_df.empty:
raise ValueError(f"No candles found for chart: {chart_name}")
@ -1121,7 +1133,9 @@ def build_dataset_from_db(config: PipelineConfig) -> dict:
# Step 3: Run annotation ingestion from database
enriched_df = pd.read_csv(enriched_path, parse_dates=["time"])
ingestion = AnnotationIngestion(config.stages.annotation_ingestion)
labeled_df = ingestion.process_from_db(enriched_df, chart_name, source="human")
# Include all annotation sources so TA-Lib generated spans (source='talib')
# can be used for training alongside manual labels.
labeled_df = ingestion.process_from_db(enriched_df, chart_name, source=None)
if labeled_df.empty:
raise ValueError(
@ -1145,7 +1159,12 @@ def build_dataset_from_db(config: PipelineConfig) -> dict:
return result
def _run_training_background(run_id: str, model_type: str, config: PipelineConfig) -> None:
def _run_training_background(
run_id: str,
model_type: str,
config: PipelineConfig,
chart_id: Optional[int] = None,
) -> None:
"""
Background thread target: build dataset then train a model.
@ -1160,7 +1179,7 @@ def _run_training_background(run_id: str, model_type: str, config: PipelineConfi
# Build dataset from database (feature engineering + annotation ingestion)
logger.info("Building dataset from database...")
build_dataset_from_db(config)
build_dataset_from_db(config, chart_id=chart_id)
labeled_path = Path(config.data.labeled_path)
if not labeled_path.exists():
@ -1336,6 +1355,15 @@ async def training_start(request: TrainingStartRequest):
detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
)
if request.chart_id is not None:
from app.data_access import DataAccess
chart = DataAccess().get_chart_by_id(request.chart_id)
if not chart:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Chart not found: id={request.chart_id}",
)
# Reject concurrent runs (atomic check-and-set)
with state.training_lock:
if state.active_training_run_id is not None:
@ -1385,7 +1413,7 @@ async def training_start(request: TrainingStartRequest):
# Launch background thread (daemon so it doesn't block process exit)
thread = threading.Thread(
target=_run_training_background,
args=(run_id, request.model_type, config),
args=(run_id, request.model_type, config, request.chart_id),
daemon=True,
name=f"training-{run_id[:8]}",
)