fix(training): use selected chart and include TA-Lib span sources
This commit is contained in:
parent
3448c6febd
commit
07064fbf40
6 changed files with 89 additions and 22 deletions
|
|
@ -99,7 +99,7 @@ class AnnotationIngestion:
|
||||||
def load_annotations_from_db(
|
def load_annotations_from_db(
|
||||||
self,
|
self,
|
||||||
chart_name: str,
|
chart_name: str,
|
||||||
source: str = "human"
|
source: Optional[str] = "human"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Load annotations directly from PostgreSQL database.
|
Load annotations directly from PostgreSQL database.
|
||||||
|
|
@ -108,7 +108,8 @@ class AnnotationIngestion:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chart_name: Name of the chart to load annotations for
|
chart_name: Name of the chart to load annotations for
|
||||||
source: Filter by annotation source ('human', 'model', 'hybrid')
|
source: Optional source filter (e.g. 'human', 'talib', 'model').
|
||||||
|
When None, includes all sources.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of annotation dictionaries compatible with existing processing
|
List of annotation dictionaries compatible with existing processing
|
||||||
|
|
@ -545,7 +546,7 @@ class AnnotationIngestion:
|
||||||
self,
|
self,
|
||||||
enriched_df: pd.DataFrame,
|
enriched_df: pd.DataFrame,
|
||||||
chart_name: str,
|
chart_name: str,
|
||||||
source: str = "human"
|
source: Optional[str] = "human"
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Main processing pipeline using direct database access.
|
Main processing pipeline using direct database access.
|
||||||
|
|
@ -555,7 +556,8 @@ class AnnotationIngestion:
|
||||||
Args:
|
Args:
|
||||||
enriched_df: DataFrame with engineered features
|
enriched_df: DataFrame with engineered features
|
||||||
chart_name: Name of the chart to load annotations for
|
chart_name: Name of the chart to load annotations for
|
||||||
source: Filter by annotation source ('human', 'model', 'hybrid')
|
source: Optional source filter (e.g. 'human', 'talib', 'model').
|
||||||
|
When None, includes all sources.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Labeled DataFrame ready for training
|
Labeled DataFrame ready for training
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,25 @@ class DataAccess:
|
||||||
if result:
|
if result:
|
||||||
return dict(result._mapping)
|
return dict(result._mapping)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_chart_by_id(self, chart_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get chart by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chart_id: Chart ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Chart dictionary or None if not found
|
||||||
|
"""
|
||||||
|
with get_db() as db:
|
||||||
|
chart_id = int(chart_id)
|
||||||
|
stmt = select(self.charts).where(self.charts.c.id == chart_id)
|
||||||
|
result = db.execute(stmt).fetchone()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return dict(result._mapping)
|
||||||
|
return None
|
||||||
|
|
||||||
def get_candles(
|
def get_candles(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1039,6 +1039,11 @@ class TrainingStartRequest(BaseModel):
|
||||||
"random_forest",
|
"random_forest",
|
||||||
description="Model type: random_forest or xgboost",
|
description="Model type: random_forest or xgboost",
|
||||||
)
|
)
|
||||||
|
chart_id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
description="Chart ID to train on. If omitted, falls back to first chart.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainingStartResponse(BaseModel):
|
class TrainingStartResponse(BaseModel):
|
||||||
|
|
@ -1072,7 +1077,7 @@ class DatasetInfoResponse(BaseModel):
|
||||||
row_count: Optional[int] = None
|
row_count: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
def build_dataset_from_db(config: PipelineConfig) -> dict:
|
def build_dataset_from_db(config: PipelineConfig, chart_id: Optional[int] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Build the labeled training dataset directly from the database.
|
Build the labeled training dataset directly from the database.
|
||||||
|
|
||||||
|
|
@ -1090,18 +1095,25 @@ def build_dataset_from_db(config: PipelineConfig) -> dict:
|
||||||
|
|
||||||
data_access = DataAccess()
|
data_access = DataAccess()
|
||||||
|
|
||||||
# Find all charts, use the first one (single-chart app)
|
# Resolve target chart
|
||||||
charts_df = data_access.get_all_charts()
|
if chart_id is not None:
|
||||||
if charts_df.empty:
|
chart = data_access.get_chart_by_id(chart_id)
|
||||||
raise ValueError("No charts found in database. Upload candle data first.")
|
if not chart:
|
||||||
|
raise ValueError(f"Chart not found: id={chart_id}")
|
||||||
|
chart_name = chart["name"]
|
||||||
|
chart_id_int = int(chart["id"])
|
||||||
|
else:
|
||||||
|
charts_df = data_access.get_all_charts()
|
||||||
|
if charts_df.empty:
|
||||||
|
raise ValueError("No charts found in database. Upload candle data first.")
|
||||||
|
chart = charts_df.iloc[0]
|
||||||
|
chart_name = chart["name"]
|
||||||
|
chart_id_int = int(chart["id"])
|
||||||
|
|
||||||
chart = charts_df.iloc[0]
|
logger.info(f"Building dataset for chart: {chart_name} (id={chart_id_int})")
|
||||||
chart_name = chart["name"]
|
|
||||||
chart_id = int(chart["id"])
|
|
||||||
logger.info(f"Building dataset for chart: {chart_name} (id={chart_id})")
|
|
||||||
|
|
||||||
# Step 1: Export candles to raw CSV
|
# Step 1: Export candles to raw CSV
|
||||||
candles_df = data_access.get_candles(chart_id)
|
candles_df = data_access.get_candles(chart_id_int)
|
||||||
if candles_df.empty:
|
if candles_df.empty:
|
||||||
raise ValueError(f"No candles found for chart: {chart_name}")
|
raise ValueError(f"No candles found for chart: {chart_name}")
|
||||||
|
|
||||||
|
|
@ -1121,7 +1133,9 @@ def build_dataset_from_db(config: PipelineConfig) -> dict:
|
||||||
# Step 3: Run annotation ingestion from database
|
# Step 3: Run annotation ingestion from database
|
||||||
enriched_df = pd.read_csv(enriched_path, parse_dates=["time"])
|
enriched_df = pd.read_csv(enriched_path, parse_dates=["time"])
|
||||||
ingestion = AnnotationIngestion(config.stages.annotation_ingestion)
|
ingestion = AnnotationIngestion(config.stages.annotation_ingestion)
|
||||||
labeled_df = ingestion.process_from_db(enriched_df, chart_name, source="human")
|
# Include all annotation sources so TA-Lib generated spans (source='talib')
|
||||||
|
# can be used for training alongside manual labels.
|
||||||
|
labeled_df = ingestion.process_from_db(enriched_df, chart_name, source=None)
|
||||||
|
|
||||||
if labeled_df.empty:
|
if labeled_df.empty:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -1145,7 +1159,12 @@ def build_dataset_from_db(config: PipelineConfig) -> dict:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _run_training_background(run_id: str, model_type: str, config: PipelineConfig) -> None:
|
def _run_training_background(
|
||||||
|
run_id: str,
|
||||||
|
model_type: str,
|
||||||
|
config: PipelineConfig,
|
||||||
|
chart_id: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Background thread target: build dataset then train a model.
|
Background thread target: build dataset then train a model.
|
||||||
|
|
||||||
|
|
@ -1160,7 +1179,7 @@ def _run_training_background(run_id: str, model_type: str, config: PipelineConfi
|
||||||
|
|
||||||
# Build dataset from database (feature engineering + annotation ingestion)
|
# Build dataset from database (feature engineering + annotation ingestion)
|
||||||
logger.info("Building dataset from database...")
|
logger.info("Building dataset from database...")
|
||||||
build_dataset_from_db(config)
|
build_dataset_from_db(config, chart_id=chart_id)
|
||||||
|
|
||||||
labeled_path = Path(config.data.labeled_path)
|
labeled_path = Path(config.data.labeled_path)
|
||||||
if not labeled_path.exists():
|
if not labeled_path.exists():
|
||||||
|
|
@ -1336,6 +1355,15 @@ async def training_start(request: TrainingStartRequest):
|
||||||
detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
|
detail=f"Unsupported model type. Available: {', '.join(SUPPORTED_MODEL_TYPES)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if request.chart_id is not None:
|
||||||
|
from app.data_access import DataAccess
|
||||||
|
chart = DataAccess().get_chart_by_id(request.chart_id)
|
||||||
|
if not chart:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Chart not found: id={request.chart_id}",
|
||||||
|
)
|
||||||
|
|
||||||
# Reject concurrent runs (atomic check-and-set)
|
# Reject concurrent runs (atomic check-and-set)
|
||||||
with state.training_lock:
|
with state.training_lock:
|
||||||
if state.active_training_run_id is not None:
|
if state.active_training_run_id is not None:
|
||||||
|
|
@ -1385,7 +1413,7 @@ async def training_start(request: TrainingStartRequest):
|
||||||
# Launch background thread (daemon so it doesn't block process exit)
|
# Launch background thread (daemon so it doesn't block process exit)
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=_run_training_background,
|
target=_run_training_background,
|
||||||
args=(run_id, request.model_type, config),
|
args=(run_id, request.model_type, config, request.chart_id),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name=f"training-{run_id[:8]}",
|
name=f"training-{run_id[:8]}",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '100
|
||||||
|
|
||||||
const TrainingStartRequestSchema = z.object({
|
const TrainingStartRequestSchema = z.object({
|
||||||
model_type: z.string().min(1, 'model_type must be a non-empty string'),
|
model_type: z.string().min(1, 'model_type must be a non-empty string'),
|
||||||
|
chart_id: z.number().int().positive().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export async function POST(request: NextRequest) {
|
export async function POST(request: NextRequest) {
|
||||||
|
|
|
||||||
|
|
@ -856,7 +856,7 @@ export default function Home() {
|
||||||
|
|
||||||
{/* Training Panel */}
|
{/* Training Panel */}
|
||||||
<div className="px-3 py-2 border-b border-sidebar-border">
|
<div className="px-3 py-2 border-b border-sidebar-border">
|
||||||
<TrainingPanel />
|
<TrainingPanel activeChartId={activeChartId} />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Predictions */}
|
{/* Predictions */}
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@ interface TrainingRun {
|
||||||
error?: string;
|
error?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface TrainingPanelProps {
|
||||||
|
activeChartId: number | null;
|
||||||
|
}
|
||||||
|
|
||||||
const MODEL_TYPES = [
|
const MODEL_TYPES = [
|
||||||
{ value: 'random_forest', label: 'Random Forest' },
|
{ value: 'random_forest', label: 'Random Forest' },
|
||||||
{ value: 'xgboost', label: 'XGBoost' },
|
{ value: 'xgboost', label: 'XGBoost' },
|
||||||
|
|
@ -58,7 +62,7 @@ function StatusBadge({ status }: { status: string }) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function TrainingPanel() {
|
export default function TrainingPanel({ activeChartId }: TrainingPanelProps) {
|
||||||
const [expanded, setExpanded] = useState(false);
|
const [expanded, setExpanded] = useState(false);
|
||||||
const [modelType, setModelType] = useState('random_forest');
|
const [modelType, setModelType] = useState('random_forest');
|
||||||
const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
|
const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
|
||||||
|
|
@ -172,13 +176,21 @@ export default function TrainingPanel() {
|
||||||
}, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);
|
}, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);
|
||||||
|
|
||||||
const handleStartTraining = async () => {
|
const handleStartTraining = async () => {
|
||||||
|
if (!activeChartId) {
|
||||||
|
setStatusMessage({
|
||||||
|
type: 'error',
|
||||||
|
text: 'Select a chart before starting training',
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setStatusMessage(null);
|
setStatusMessage(null);
|
||||||
setIsTraining(true);
|
setIsTraining(true);
|
||||||
try {
|
try {
|
||||||
const res = await fetch('/api/training/start', {
|
const res = await fetch('/api/training/start', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ model_type: modelType }),
|
body: JSON.stringify({ model_type: modelType, chart_id: activeChartId }),
|
||||||
});
|
});
|
||||||
|
|
||||||
if (res.status === 409) {
|
if (res.status === 409) {
|
||||||
|
|
@ -227,7 +239,7 @@ export default function TrainingPanel() {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const canTrain = !isTraining;
|
const canTrain = !isTraining && !!activeChartId;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
|
|
@ -308,6 +320,11 @@ export default function TrainingPanel() {
|
||||||
'Start Training'
|
'Start Training'
|
||||||
)}
|
)}
|
||||||
</button>
|
</button>
|
||||||
|
{!activeChartId && (
|
||||||
|
<p className="text-[10px] text-muted-foreground">
|
||||||
|
Select a chart to enable training.
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Status message */}
|
{/* Status message */}
|
||||||
{statusMessage && (
|
{statusMessage && (
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue