candle-annotator/src/components/TrainingPanel.tsx

377 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

'use client';
import { useState, useEffect, useRef, useCallback } from 'react';
import { ChevronDown, Trash2 } from 'lucide-react';
/**
 * Metadata about the cached training dataset, as returned by
 * `/api/training/dataset-info`.
 */
interface DatasetInfo {
  /** Whether a cached dataset is present; when false the panel explains one will be built at training start. */
  exists: boolean;
  /** Dataset location as reported by the backend (not displayed by the panel). */
  path: string;
  /** Row count, displayed when present. */
  row_count?: number;
  /** Dataset size — presumably in bytes, per the name; not displayed by the panel. */
  size_bytes?: number;
  /** Last modification time — format not visible here (presumably ISO-8601; confirm with backend). */
  last_modified?: string;
}
/**
 * One training run record as returned by `/api/training/runs`.
 */
interface TrainingRun {
  run_id: string;
  model_type: string;
  /** Lifecycle state; only non-running runs can be deleted from the panel. */
  status: 'running' | 'completed' | 'failed';
  /** Timestamp string handed to `formatDate` (presumably ISO-8601). */
  created_at: string;
  completed_at?: string;
  /** Failure description, shown for failed runs. */
  error?: string;
  /**
   * Named evaluation metrics. The panel renders them as percentages (value × 100),
   * so they are presumably ratios in [0, 1] — confirm with the backend.
   */
  metrics_summary?: {
    accuracy?: number;
    test_accuracy?: number;
    f1_macro?: number;
    test_f1_macro?: number;
    [key: string]: number | undefined;
  };
}
/**
 * Model choices offered in the selector. `value` is sent to the backend as
 * `model_type` when training starts; `label` is what the user sees.
 */
const MODEL_TYPES: Array<{ value: string; label: string }> = [
  {
    value: 'random_forest',
    label: 'Random Forest',
  },
  {
    value: 'xgboost',
    label: 'XGBoost',
  },
];
/**
 * Format a timestamp string as a short, locale-dependent month/day/hour/minute label.
 *
 * @param iso - Timestamp to format (presumably ISO-8601 from the backend).
 * @returns The localized label, or `iso` unchanged when it cannot be parsed or formatted.
 */
function formatDate(iso: string): string {
  const date = new Date(iso);
  // `new Date` never throws on bad input: it produces an Invalid Date whose
  // toLocaleString() is the literal text "Invalid Date". Detect that explicitly
  // so the raw value is shown instead.
  if (Number.isNaN(date.getTime())) return iso;
  try {
    return date.toLocaleString(undefined, {
      month: 'short',
      day: 'numeric',
      hour: '2-digit',
      minute: '2-digit',
    });
  } catch {
    // Intl can throw RangeError in unusual runtime configurations; fall back to the raw value.
    return iso;
  }
}
/**
 * Small colored pill displaying a run's status text. Known statuses
 * (running / completed / failed) get their own color; anything else gets neutral styling.
 */
function StatusBadge({ status }: { status: string }) {
  const palette: Record<string, string> = {
    running: 'bg-blue-500/20 text-blue-500',
    completed: 'bg-green-500/20 text-green-600',
    failed: 'bg-red-500/20 text-destructive',
  };
  const tone = palette[status] || 'bg-secondary text-muted-foreground';
  const badgeClass = `px-1 py-0.5 rounded text-[9px] font-medium ${tone}`;
  return <span className={badgeClass}>{status}</span>;
}
/**
 * Read a response body as a JSON object. Returns `{}` when the body is empty,
 * not valid JSON, or not a plain object, so error paths never surface a JSON
 * SyntaxError to the user.
 */
async function readJsonObject(res: Response): Promise<Record<string, unknown>> {
  try {
    const body: unknown = await res.json();
    return body !== null && typeof body === 'object' && !Array.isArray(body)
      ? (body as Record<string, unknown>)
      : {};
  } catch {
    return {};
  }
}

/**
 * Pick a user-facing message from an API error body: `error` first, then
 * `detail`, accepting only non-empty strings (a structured value could not be
 * rendered as message text), else `fallback`.
 */
function apiErrorMessage(body: Record<string, unknown>, fallback: string): string {
  const message = [body.error, body.detail].find(
    (value): value is string => typeof value === 'string' && value.length > 0,
  );
  return message ?? fallback;
}

/** Normalize the runs endpoint payload, which may be a bare array or `{ runs: [...] }`. */
function extractRuns(payload: unknown): TrainingRun[] {
  if (Array.isArray(payload)) return payload as TrainingRun[];
  if (payload !== null && typeof payload === 'object') {
    const { runs } = payload as { runs?: unknown };
    if (Array.isArray(runs)) return runs as TrainingRun[];
  }
  return [];
}

/** First argument that is a finite number; skips missing values and JSON `null`s. */
function firstNumber(...values: Array<number | null | undefined>): number | undefined {
  return values.find((value): value is number => typeof value === 'number' && Number.isFinite(value));
}

/** Render a metric as a percentage (value × 100, one decimal) — assumes metrics are ratios in [0, 1]. */
function formatPercent(ratio: number): string {
  return `${(ratio * 100).toFixed(1)}%`;
}

/** One-line "name: value" summary of every reported metric; nullish metrics are omitted. */
function summarizeMetrics(metrics: TrainingRun['metrics_summary']): string {
  if (!metrics) return '';
  return Object.entries(metrics)
    .filter(([, value]) => value != null)
    .map(([key, value]) => `${key}: ${typeof value === 'number' ? formatPercent(value) : String(value)}`)
    .join(', ');
}

/** Display label for a model type: the selector's label when known, else the id with every `_` as a space. */
function modelTypeLabel(modelType: string): string {
  return MODEL_TYPES.find((m) => m.value === modelType)?.label ?? modelType.replace(/_/g, ' ');
}

/**
 * Collapsible panel for launching model training and reviewing recent runs.
 *
 * Expanding the panel loads dataset info, the run list, and the backend's active
 * run. While a run is active the panel polls every 5 s (also while collapsed, so
 * the header indicator stays live) and reports the outcome once the backend says
 * nothing is running any more.
 */
export default function TrainingPanel() {
  const [expanded, setExpanded] = useState(false);
  const [modelType, setModelType] = useState('random_forest');
  const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
  const [runs, setRuns] = useState<TrainingRun[]>([]);
  const [isTraining, setIsTraining] = useState(false);
  const [activeRunId, setActiveRunId] = useState<string | null>(null);
  const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
  const [isLoadingDataset, setIsLoadingDataset] = useState(false);
  const [isLoadingRuns, setIsLoadingRuns] = useState(false);
  const [deletingRunIds, setDeletingRunIds] = useState<Set<string>>(new Set());
  // Portable timer handle type: browser setInterval returns a number, Node returns a Timeout.
  const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  const fetchDatasetInfo = useCallback(async () => {
    setIsLoadingDataset(true);
    try {
      const res = await fetch('/api/training/dataset-info');
      if (res.ok) setDatasetInfo(await res.json());
    } catch {
      // Best-effort: keep the previous info (or the "unavailable" placeholder).
    } finally {
      setIsLoadingDataset(false);
    }
  }, []);

  // Refresh the list (the UI keeps only the first 5 entries — presumably newest-first;
  // confirm server ordering) and return the full list to the caller.
  const fetchRuns = useCallback(async (): Promise<TrainingRun[]> => {
    try {
      const res = await fetch('/api/training/runs');
      if (!res.ok) return [];
      const runList = extractRuns(await res.json());
      setRuns(runList.slice(0, 5));
      return runList;
    } catch {
      return [];
    }
  }, []);

  // Ask the backend which run is active. Returns the run id, `null` when nothing is
  // running, or `undefined` when the answer is unknown (network failure, non-2xx,
  // unparseable body) so callers never mistake an outage for "training finished".
  const fetchActiveRun = useCallback(async (): Promise<string | null | undefined> => {
    try {
      const res = await fetch('/api/training/active');
      if (!res.ok) return undefined;
      const data = (await res.json()) as { active?: unknown; run_id?: unknown } | null;
      if (!data?.active) return null;
      return typeof data.run_id === 'string' && data.run_id ? data.run_id : null;
    } catch {
      return undefined;
    }
  }, []);

  // Load data when the panel expands and sync the training state with the backend.
  useEffect(() => {
    if (!expanded) return undefined;
    let cancelled = false;
    void fetchDatasetInfo();
    setIsLoadingRuns(true);
    void Promise.all([fetchRuns(), fetchActiveRun()])
      .then(([, serverActiveRunId]) => {
        // Ignore results after collapse, and leave state alone if the server couldn't answer.
        if (cancelled || serverActiveRunId === undefined) return;
        if (serverActiveRunId) {
          setActiveRunId(serverActiveRunId);
          setIsTraining(true);
        } else {
          // Ensure we are not stuck in the training state.
          setIsTraining(false);
          setActiveRunId(null);
        }
      })
      .finally(() => {
        if (!cancelled) setIsLoadingRuns(false);
      });
    return () => {
      cancelled = true;
    };
  }, [expanded, fetchDatasetInfo, fetchRuns, fetchActiveRun]);

  // Poll while a run is active; report the outcome once the backend says nothing is running.
  useEffect(() => {
    if (!isTraining || !activeRunId) return undefined;
    let cancelled = false;
    let tickInFlight = false;

    const tick = async () => {
      // Don't stack requests if the previous poll is still waiting on the server.
      if (tickInFlight) return;
      tickInFlight = true;
      try {
        const [updatedRuns, serverActiveRunId] = await Promise.all([fetchRuns(), fetchActiveRun()]);
        // This effect was torn down while the requests were in flight.
        if (cancelled) return;
        // A run id means still running; undefined means the server was unreachable — retry next tick.
        if (serverActiveRunId !== null) return;

        clearInterval(intervalId);
        pollIntervalRef.current = null;
        setIsTraining(false);
        setActiveRunId(null);

        const finishedRun = updatedRuns.find((r) => r.run_id === activeRunId);
        if (finishedRun?.status === 'completed') {
          const metricStr = summarizeMetrics(finishedRun.metrics_summary);
          setStatusMessage({
            type: 'success',
            text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
          });
        } else if (finishedRun?.status === 'failed') {
          setStatusMessage({
            type: 'error',
            text: finishedRun.error || 'Training failed',
          });
        }
      } finally {
        tickInFlight = false;
      }
    };

    const intervalId = setInterval(() => {
      void tick();
    }, 5000);
    pollIntervalRef.current = intervalId;

    return () => {
      cancelled = true;
      clearInterval(intervalId);
      pollIntervalRef.current = null;
    };
  }, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);

  const handleStartTraining = async () => {
    setStatusMessage(null);
    setIsTraining(true);
    try {
      const res = await fetch('/api/training/start', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ model_type: modelType }),
      });
      // Error bodies are not guaranteed to be JSON; readJsonObject tolerates that.
      const data = await readJsonObject(res);
      // 409 means a run is already in progress server-side: adopt it instead of failing.
      const alreadyRunning = res.status === 409;
      if (!res.ok && !alreadyRunning) {
        throw new Error(apiErrorMessage(data, 'Failed to start training'));
      }
      // The poller needs a run id. If the response omits it, ask the server;
      // otherwise the UI would stay "in progress" indefinitely.
      const bodyRunId = typeof data.run_id === 'string' && data.run_id ? data.run_id : null;
      const runId = bodyRunId ?? (await fetchActiveRun()) ?? null;
      if (runId) {
        setActiveRunId(runId);
      } else {
        // Nothing trackable is running (e.g. it already finished); re-enable the form.
        setIsTraining(false);
      }
      await fetchRuns();
    } catch (e: unknown) {
      setIsTraining(false);
      setStatusMessage({
        type: 'error',
        text: e instanceof Error ? e.message : 'Failed to start training',
      });
    }
  };

  const handleDeleteRun = async (runId: string) => {
    setDeletingRunIds((prev) => new Set(prev).add(runId));
    try {
      // Encode so an id can never inject extra path segments or a query string.
      const res = await fetch(`/api/training/runs/${encodeURIComponent(runId)}`, { method: 'DELETE' });
      if (!res.ok) {
        const data = await readJsonObject(res);
        setStatusMessage({ type: 'error', text: apiErrorMessage(data, 'Failed to delete run') });
        return;
      }
      await fetchRuns();
    } catch {
      setStatusMessage({ type: 'error', text: 'Failed to delete run' });
    } finally {
      setDeletingRunIds((prev) => {
        const next = new Set(prev);
        next.delete(runId);
        return next;
      });
    }
  };

  const canTrain = !isTraining;

  return (
    <div>
      <button
        type="button"
        onClick={() => setExpanded((prev) => !prev)}
        aria-expanded={expanded}
        className="w-full flex items-center justify-between py-1"
      >
        <div className="flex items-center gap-1.5">
          {isTraining && (
            <div className="w-1.5 h-1.5 rounded-full bg-blue-500 animate-pulse" />
          )}
          <span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
            Training
          </span>
        </div>
        <ChevronDown
          className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
        />
      </button>
      {expanded && (
        <div className="mt-1 space-y-2">
          {/* Dataset info */}
          <div className="p-2 bg-secondary/30 rounded text-[10px] space-y-0.5">
            {isLoadingDataset ? (
              <p className="text-muted-foreground">Checking dataset...</p>
            ) : datasetInfo === null ? (
              <p className="text-muted-foreground">Dataset info unavailable</p>
            ) : datasetInfo.exists ? (
              <>
                <div className="flex justify-between">
                  <span className="text-muted-foreground">Dataset:</span>
                  <span className="text-green-600 font-mono">Ready</span>
                </div>
                {/* typeof guard: a JSON null row_count would crash toLocaleString(). */}
                {typeof datasetInfo.row_count === 'number' && (
                  <div className="flex justify-between">
                    <span className="text-muted-foreground">Rows:</span>
                    <span className="font-mono text-foreground">{datasetInfo.row_count.toLocaleString()}</span>
                  </div>
                )}
              </>
            ) : (
              <p className="text-muted-foreground">
                No cached dataset. It will be built automatically from your annotations when training starts.
              </p>
            )}
          </div>
          {/* Model type selector */}
          <div>
            <label className="text-[10px] text-muted-foreground block mb-0.5">Model Type</label>
            <select
              value={modelType}
              onChange={(e) => setModelType(e.target.value)}
              disabled={isTraining}
              className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
            >
              {MODEL_TYPES.map((m) => (
                <option key={m.value} value={m.value}>
                  {m.label}
                </option>
              ))}
            </select>
          </div>
          {/* Start Training button */}
          <button
            type="button"
            onClick={handleStartTraining}
            disabled={!canTrain}
            className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
          >
            {isTraining ? (
              <span className="flex items-center justify-center gap-1.5">
                <span className="inline-block w-2 h-2 rounded-full border border-primary-foreground border-t-transparent animate-spin" />
                Training in progress...
              </span>
            ) : (
              'Start Training'
            )}
          </button>
          {/* Status message */}
          {statusMessage && (
            <p
              className={`text-[10px] ${
                statusMessage.type === 'success' ? 'text-green-600' : 'text-destructive'
              }`}
            >
              {statusMessage.text}
            </p>
          )}
          {/* Training run history */}
          <div>
            <p className="text-[10px] text-muted-foreground mb-1">Recent runs</p>
            {isLoadingRuns ? (
              <p className="text-[10px] text-muted-foreground">Loading...</p>
            ) : runs.length === 0 ? (
              <p className="text-[10px] text-muted-foreground">No training runs yet</p>
            ) : (
              <div className="space-y-1">
                {runs.map((run) => {
                  // Prefer the un-prefixed metric, falling back to its test_ variant; nulls are skipped.
                  const accuracy = firstNumber(run.metrics_summary?.accuracy, run.metrics_summary?.test_accuracy);
                  const f1 = firstNumber(run.metrics_summary?.f1_macro, run.metrics_summary?.test_f1_macro);
                  return (
                    <div key={run.run_id} className="p-1.5 bg-secondary/30 rounded text-[10px] space-y-0.5">
                      <div className="flex items-center justify-between gap-1">
                        <span className="text-foreground font-medium capitalize">
                          {modelTypeLabel(run.model_type)}
                        </span>
                        <div className="flex items-center gap-1">
                          <StatusBadge status={run.status} />
                          {run.status !== 'running' && (
                            <button
                              type="button"
                              onClick={() => handleDeleteRun(run.run_id)}
                              disabled={deletingRunIds.has(run.run_id)}
                              title="Delete run"
                              aria-label="Delete run"
                              className="text-muted-foreground hover:text-destructive transition-colors disabled:opacity-40"
                            >
                              <Trash2 className="w-2.5 h-2.5" />
                            </button>
                          )}
                        </div>
                      </div>
                      <div className="text-muted-foreground">{formatDate(run.created_at)}</div>
                      {run.status === 'completed' && run.metrics_summary && (
                        <div className="flex flex-wrap gap-x-2 text-muted-foreground">
                          {accuracy !== undefined && <span>Acc: {formatPercent(accuracy)}</span>}
                          {f1 !== undefined && <span>F1: {formatPercent(f1)}</span>}
                        </div>
                      )}
                      {run.status === 'failed' && run.error && (
                        <p className="text-destructive truncate" title={run.error}>
                          {run.error}
                        </p>
                      )}
                    </div>
                  );
                })}
              </div>
            )}
          </div>
        </div>
      )}
    </div>
  );
}