'use client';

import { useState, useEffect, useRef, useCallback } from 'react';
import { ChevronDown, Trash2 } from 'lucide-react';

/** Cached-dataset metadata as returned by `/api/training/dataset-info`. */
interface DatasetInfo {
  path: string;
  exists: boolean;
  size_bytes?: number;
  last_modified?: string;
  row_count?: number;
}

/** A single training run as returned by `/api/training/runs`. */
interface TrainingRun {
  run_id: string;
  model_type: string;
  status: 'running' | 'completed' | 'failed';
  created_at: string;
  completed_at?: string;
  metrics_summary?: {
    accuracy?: number;
    f1_macro?: number;
    test_accuracy?: number;
    test_f1_macro?: number;
    [key: string]: number | undefined;
  };
  error?: string;
}

/** Banner shown under the Start button after a run finishes or a request fails. */
type StatusMessage = { type: 'success' | 'error'; text: string };

const MODEL_TYPES = [
  { value: 'random_forest', label: 'Random Forest' },
  { value: 'xgboost', label: 'XGBoost' },
];

/** How often (ms) the backend is polled while a run is active. */
const POLL_INTERVAL_MS = 5000;

/** Number of runs kept in the "Recent runs" list. */
const MAX_VISIBLE_RUNS = 5;

/**
 * Formats an ISO timestamp as a short local date/time (e.g. "Mar 4, 02:15 PM").
 * Returns the input unchanged when it is not a parseable date.
 */
function formatDate(iso: string): string {
  const date = new Date(iso);
  // `new Date` never throws on bad input; it yields an Invalid Date whose
  // toLocaleString() is the literal text "Invalid Date". Fall back to the raw value.
  if (Number.isNaN(date.getTime())) return iso;
  try {
    return date.toLocaleString(undefined, {
      month: 'short',
      day: 'numeric',
      hour: '2-digit',
      minute: '2-digit',
    });
  } catch {
    // toLocaleString can throw RangeError for unsupported locale/options.
    return iso;
  }
}

/** Formats a 0–1 ratio as a one-decimal percentage, e.g. 0.9123 -> "91.2%". */
function formatPercent(value: number): string {
  return `${(value * 100).toFixed(1)}%`;
}

/** Returns the first argument that is an actual number (skips undefined and JSON null). */
function firstNumber(...values: Array<number | undefined>): number | undefined {
  return values.find((v): v is number => typeof v === 'number');
}

/** Renders "key: NN.N%" pairs for every numeric metric; absent metrics are omitted. */
function summarizeMetrics(metrics: TrainingRun['metrics_summary']): string {
  if (!metrics) return '';
  return Object.entries(metrics)
    .filter((entry): entry is [string, number] => typeof entry[1] === 'number')
    .map(([key, value]) => `${key}: ${formatPercent(value)}`)
    .join(', ');
}

/** Human-readable model type: every underscore becomes a space ("a_b_c" -> "a b c"). */
function formatModelType(modelType: string): string {
  return modelType.replace(/_/g, ' ');
}

/**
 * Extracts an error message from a failed response body (`error`, then `detail`),
 * falling back to `fallback` when the body is not JSON or the fields are not strings.
 */
async function readErrorMessage(res: Response, fallback: string): Promise<string> {
  const data: unknown = await res.json().catch(() => null);
  if (data !== null && typeof data === 'object') {
    const { error, detail } = data as { error?: unknown; detail?: unknown };
    if (typeof error === 'string' && error) return error;
    if (typeof detail === 'string' && detail) return detail;
  }
  return fallback;
}

/** Small colored pill showing a run's status. */
function StatusBadge({ status }: { status: string }) {
  const colors: Record<string, string> = {
    running: 'bg-blue-500/20 text-blue-500',
    completed: 'bg-green-500/20 text-green-600',
    failed: 'bg-red-500/20 text-destructive',
  };
  return (
    <span
      className={`rounded px-1.5 py-0.5 text-[10px] font-medium ${
        colors[status] ?? 'bg-muted text-muted-foreground'
      }`}
    >
      {status}
    </span>
  );
}

/**
 * Collapsible panel that starts model training runs, polls the backend while a
 * run is active, and lists (and deletes) the most recent runs.
 *
 * Data is loaded lazily: nothing is fetched until the panel is first expanded.
 */
export default function TrainingPanel() {
  const [expanded, setExpanded] = useState(false);
  const [modelType, setModelType] = useState('random_forest');
  const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
  const [runs, setRuns] = useState<TrainingRun[]>([]);
  const [isTraining, setIsTraining] = useState(false);
  const [activeRunId, setActiveRunId] = useState<string | null>(null);
  const [statusMessage, setStatusMessage] = useState<StatusMessage | null>(null);
  const [isLoadingDataset, setIsLoadingDataset] = useState(false);
  const [isLoadingRuns, setIsLoadingRuns] = useState(false);
  const [deletingRunIds, setDeletingRunIds] = useState<Set<string>>(new Set());
  const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  // Best-effort: on any failure the previously loaded dataset info is kept.
  const fetchDatasetInfo = useCallback(async () => {
    setIsLoadingDataset(true);
    try {
      const res = await fetch('/api/training/dataset-info');
      if (res.ok) setDatasetInfo(await res.json());
    } catch {
      // ignore
    } finally {
      setIsLoadingDataset(false);
    }
  }, []);

  // Refreshes the visible run list and returns the FULL list (not truncated),
  // so callers can look up runs that fell off the visible slice. Returns [] on error.
  const fetchRuns = useCallback(async (): Promise<TrainingRun[]> => {
    try {
      const res = await fetch('/api/training/runs');
      if (!res.ok) return [];
      const data = await res.json();
      // The endpoint may return either `{ runs: [...] }` or a bare array.
      const runList: unknown = Array.isArray(data) ? data : data?.runs;
      if (!Array.isArray(runList)) return [];
      setRuns(runList.slice(0, MAX_VISIBLE_RUNS));
      return runList;
    } catch {
      return [];
    }
  }, []);

  // Asks the backend which run is authoritatively active.
  //   string    -> id of the active run
  //   null      -> the server reports nothing is running
  //   undefined -> status unknown (network / HTTP error); callers must not treat
  //                this as "finished", or a transient blip ends tracking early.
  const fetchActiveRun = useCallback(async (): Promise<string | null | undefined> => {
    try {
      const res = await fetch('/api/training/active');
      if (!res.ok) return undefined;
      const data = await res.json();
      return data.active ? (data.run_id ?? null) : null;
    } catch {
      return undefined;
    }
  }, []);

  // Use the id from a response body when present, otherwise ask the backend.
  const resolveRunId = async (candidate: unknown): Promise<string | null> =>
    typeof candidate === 'string' && candidate !== ''
      ? candidate
      : ((await fetchActiveRun()) ?? null);

  // Load data when the panel expands and resync with the backend's active run.
  useEffect(() => {
    if (!expanded) return;
    let cancelled = false;

    void fetchDatasetInfo();
    setIsLoadingRuns(true);
    Promise.all([fetchRuns(), fetchActiveRun()])
      .then(([, serverActiveRunId]) => {
        if (cancelled) return;
        if (typeof serverActiveRunId === 'string') {
          setActiveRunId(serverActiveRunId);
          setIsTraining(true);
        } else if (serverActiveRunId === null) {
          // Ensure we are not stuck in training state.
          setIsTraining(false);
          setActiveRunId(null);
        }
        // undefined: status unknown, so keep whatever we were already tracking.
      })
      .finally(() => {
        if (!cancelled) setIsLoadingRuns(false);
      });

    return () => {
      cancelled = true;
    };
  }, [expanded, fetchDatasetInfo, fetchRuns, fetchActiveRun]);

  // Poll while a run is being tracked.
  useEffect(() => {
    if (!isTraining || !activeRunId) {
      if (pollIntervalRef.current) {
        clearInterval(pollIntervalRef.current);
        pollIntervalRef.current = null;
      }
      return;
    }

    // Set in cleanup so a poll already awaiting the network cannot apply
    // results after this effect instance was torn down (run id changed, unmount).
    let cancelled = false;
    // Prevents overlapping polls when a request takes longer than the interval.
    let inFlight = false;

    const poll = async () => {
      if (inFlight) return;
      inFlight = true;
      try {
        const [updatedRuns, serverActiveRunId] = await Promise.all([fetchRuns(), fetchActiveRun()]);
        if (cancelled) return;
        // Still running (string) or status unknown (undefined): check again next tick.
        if (serverActiveRunId !== null) return;

        // The server says nothing is active, so training is done.
        clearInterval(intervalId);
        if (pollIntervalRef.current === intervalId) pollIntervalRef.current = null;
        setIsTraining(false);
        setActiveRunId(null);

        const finishedRun = updatedRuns.find((r) => r.run_id === activeRunId);
        if (finishedRun?.status === 'completed') {
          const metricStr = summarizeMetrics(finishedRun.metrics_summary);
          setStatusMessage({
            type: 'success',
            text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
          });
        } else if (finishedRun?.status === 'failed') {
          setStatusMessage({
            type: 'error',
            text: finishedRun.error || 'Training failed',
          });
        }
      } finally {
        inFlight = false;
      }
    };

    const intervalId = setInterval(() => {
      void poll();
    }, POLL_INTERVAL_MS);
    pollIntervalRef.current = intervalId;

    return () => {
      cancelled = true;
      clearInterval(intervalId);
      if (pollIntervalRef.current === intervalId) pollIntervalRef.current = null;
    };
  }, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);

  const handleStartTraining = async () => {
    setStatusMessage(null);
    // Lock the button immediately; every failure path below resets it.
    setIsTraining(true);
    try {
      const res = await fetch('/api/training/start', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ model_type: modelType }),
      });

      if (res.status === 409) {
        // Already training: adopt the existing run so the poller tracks it.
        // Without a run id the poller never starts and the panel would be stuck.
        const data = await res.json().catch(() => ({}));
        const runId = await resolveRunId(data.run_id);
        if (!runId) throw new Error('A training run is already in progress');
        setActiveRunId(runId);
        return;
      }

      if (!res.ok) {
        throw new Error(await readErrorMessage(res, 'Failed to start training'));
      }

      const data = await res.json().catch(() => ({}));
      const runId = await resolveRunId(data.run_id);
      if (runId) {
        setActiveRunId(runId);
      } else {
        // Started but not trackable (e.g. it already finished): unlock and let the
        // refreshed list below show the run's outcome.
        setIsTraining(false);
      }
      // Refresh runs list
      await fetchRuns();
    } catch (e) {
      setIsTraining(false);
      setStatusMessage({
        type: 'error',
        text: e instanceof Error ? e.message : 'Failed to start training',
      });
    }
  };

  const handleDeleteRun = async (runId: string) => {
    setDeletingRunIds((prev) => new Set(prev).add(runId));
    try {
      const res = await fetch(`/api/training/runs/${encodeURIComponent(runId)}`, { method: 'DELETE' });
      if (!res.ok) {
        setStatusMessage({ type: 'error', text: await readErrorMessage(res, 'Failed to delete run') });
        return;
      }
      await fetchRuns();
    } catch {
      setStatusMessage({ type: 'error', text: 'Failed to delete run' });
    } finally {
      setDeletingRunIds((prev) => {
        const next = new Set(prev);
        next.delete(runId);
        return next;
      });
    }
  };

  const canTrain = !isTraining;

  // NOTE(review): the original JSX was not recoverable from this file (tags were
  // stripped). The structure, class names, header title and button labels below
  // are a reconstruction. Confirm them against the design before merging.
  return (
    <div className="rounded-lg border text-sm">
      <button
        type="button"
        onClick={() => setExpanded((prev) => !prev)}
        aria-expanded={expanded}
        className="flex w-full items-center justify-between px-3 py-2 font-medium"
      >
        <span>Model Training</span>
        <ChevronDown className={`h-4 w-4 transition-transform ${expanded ? 'rotate-180' : ''}`} />
      </button>

      {expanded && (
        <div className="space-y-3 border-t px-3 py-3">
          {/* Dataset info */}
          <div className="space-y-0.5 text-xs">
            {isLoadingDataset ? (
              <div className="text-muted-foreground">Checking dataset...</div>
            ) : datasetInfo === null ? (
              <div className="text-muted-foreground">Dataset info unavailable</div>
            ) : datasetInfo.exists ? (
              <>
                <div>Dataset: Ready</div>
                {/* `!= null` also skips a JSON null, which would crash toLocaleString(). */}
                {datasetInfo.row_count != null && (
                  <div className="text-muted-foreground">Rows: {datasetInfo.row_count.toLocaleString()}</div>
                )}
              </>
            ) : (
              <div className="text-muted-foreground">
                No cached dataset. It will be built automatically from your annotations when training starts.
              </div>
            )}
          </div>

          {/* Model type selector */}
          <select
            aria-label="Model type"
            value={modelType}
            onChange={(e) => setModelType(e.target.value)}
            disabled={!canTrain}
            className="w-full rounded border bg-background px-2 py-1 text-xs"
          >
            {MODEL_TYPES.map((m) => (
              <option key={m.value} value={m.value}>
                {m.label}
              </option>
            ))}
          </select>

          {/* Start Training button */}
          <button
            type="button"
            onClick={() => void handleStartTraining()}
            disabled={!canTrain}
            className="w-full rounded bg-primary px-3 py-1.5 text-xs font-medium text-primary-foreground disabled:opacity-50"
          >
            {isTraining ? 'Training...' : 'Start Training'}
          </button>

          {/* Status message */}
          {statusMessage && (
            <div
              className={`rounded px-2 py-1 text-xs ${
                statusMessage.type === 'success' ? 'bg-green-500/10 text-green-600' : 'bg-red-500/10 text-destructive'
              }`}
            >
              {statusMessage.text}
            </div>
          )}

          {/* Training run history */}
          <div className="space-y-1">
            <div className="text-xs font-medium">Recent runs</div>

            {isLoadingRuns ? (
              <div className="text-xs text-muted-foreground">Loading...</div>
            ) : runs.length === 0 ? (
              <div className="text-xs text-muted-foreground">No training runs yet</div>
            ) : (
              <ul className="space-y-1">
                {runs.map((run) => {
                  const metrics = run.metrics_summary;
                  const accuracy = firstNumber(metrics?.accuracy, metrics?.test_accuracy);
                  const f1 = firstNumber(metrics?.f1_macro, metrics?.test_f1_macro);
                  const isDeleting = deletingRunIds.has(run.run_id);

                  return (
                    <li key={run.run_id} className="rounded border px-2 py-1 text-xs">
                      <div className="flex items-center justify-between gap-2">
                        <span className="font-medium">{formatModelType(run.model_type)}</span>
                        <div className="flex items-center gap-1">
                          <StatusBadge status={run.status} />
                          {run.status !== 'running' && (
                            <button
                              type="button"
                              onClick={() => void handleDeleteRun(run.run_id)}
                              disabled={isDeleting}
                              aria-label="Delete run"
                              className="rounded p-0.5 text-muted-foreground hover:text-destructive disabled:opacity-50"
                            >
                              <Trash2 className="h-3.5 w-3.5" />
                            </button>
                          )}
                        </div>
                      </div>

                      <div className="text-muted-foreground">{formatDate(run.created_at)}</div>

                      {run.status === 'completed' && metrics && (
                        <div className="flex gap-3">
                          {accuracy !== undefined && <span>Acc: {formatPercent(accuracy)}</span>}
                          {f1 !== undefined && <span>F1: {formatPercent(f1)}</span>}
                        </div>
                      )}

                      {run.status === 'failed' && run.error && (
                        <div className="text-destructive">{run.error}</div>
                      )}
                    </li>
                  );
                })}
              </ul>
            )}
          </div>
        </div>
      )}
    </div>
  );
}