// Source listing: 377 lines · 13 KiB · TypeScript
'use client';
import { useState, useEffect, useRef, useCallback } from 'react';
import { ChevronDown, Trash2 } from 'lucide-react';
/**
 * Dataset status as returned by `GET /api/training/dataset-info`.
 *
 * This panel only renders `exists` and `row_count`; the other fields are
 * stored from the response but not read anywhere in this file.
 */
interface DatasetInfo {
  /** Location of the dataset (not displayed by this panel). */
  path: string;
  /** `true` shows "Ready"; `false` shows the "will be built automatically" hint. */
  exists: boolean;
  /** Presumably the size in bytes, going by the name — not read in this file. */
  size_bytes?: number;
  /** Presumably a timestamp string — its format is not established in this file. */
  last_modified?: string;
  /** Row count, rendered next to "Rows:" when present. */
  row_count?: number;
}

/**
 * One entry of the run list returned by `GET /api/training/runs`.
 */
interface TrainingRun {
  /** Identifier used as the React key, for deletion, and to match the active run. */
  run_id: string;
  /** Model type key; compared against the `value` entries of the model selector. */
  model_type: string;
  /** Lifecycle state; runs in the `running` state cannot be deleted from the UI. */
  status: 'running' | 'completed' | 'failed';
  /** Creation timestamp; parsed with `new Date(...)` for display. */
  created_at: string;
  /** Not read in this file. */
  completed_at?: string;
  /**
   * Metrics of a completed run. The UI multiplies values by 100 and shows them
   * as percentages, so they are treated as fractions.
   * NOTE(review): confirm every metric the backend emits is a 0..1 fraction.
   */
  metrics_summary?: {
    /** Preferred accuracy value; `test_accuracy` is the fallback. */
    accuracy?: number;
    /** Preferred macro F1 value; `test_f1_macro` is the fallback. */
    f1_macro?: number;
    test_accuracy?: number;
    test_f1_macro?: number;
    /** Any additional metric the backend chooses to report. */
    [key: string]: number | undefined;
  };
  /** Failure description, shown when `status` is `failed`. */
  error?: string;
}

/**
 * Options offered by the model selector.
 *
 * `value` is what gets sent as `model_type` when training starts; `label` is
 * the text shown to the user. Declared read-only because the list is shared
 * module state and is never meant to be mutated.
 */
const MODEL_TYPES: ReadonlyArray<{ value: string; label: string }> = [
  { value: 'random_forest', label: 'Random Forest' },
  { value: 'xgboost', label: 'XGBoost' },
];

/**
 * Formats a timestamp string as a short, locale-aware date and time
 * (abbreviated month, day, two-digit hour and minute).
 *
 * @param iso - Timestamp string accepted by the `Date` constructor.
 * @returns The formatted date, or `iso` unchanged when it cannot be parsed
 *   or formatted.
 */
function formatDate(iso: string): string {
  const parsed = new Date(iso);

  // toLocaleString() does not throw for an unparsable date; it returns the
  // literal text "Invalid Date". Check validity explicitly so the raw input
  // is shown instead.
  if (Number.isNaN(parsed.getTime())) {
    return iso;
  }

  try {
    return parsed.toLocaleString(undefined, {
      month: 'short',
      day: 'numeric',
      hour: '2-digit',
      minute: '2-digit',
    });
  } catch {
    // Defensive: formatting itself failed, so fall back to the raw string.
    return iso;
  }
}

/** Tailwind classes for each run status the badge knows about. */
const STATUS_BADGE_CLASSES: Record<string, string> = {
  running: 'bg-blue-500/20 text-blue-500',
  completed: 'bg-green-500/20 text-green-600',
  failed: 'bg-red-500/20 text-destructive',
};

/** Neutral styling used for any status without an entry above. */
const STATUS_BADGE_FALLBACK = 'bg-secondary text-muted-foreground';

/**
 * Small coloured pill showing a run's status text.
 *
 * @param status - Status string to display; unknown values get neutral styling.
 */
function StatusBadge({ status }: { status: string }) {
  // Own-property check: a plain index lookup would also resolve inherited keys
  // such as "constructor" or "toString" and interpolate a function into the
  // class name.
  const palette = Object.prototype.hasOwnProperty.call(STATUS_BADGE_CLASSES, status)
    ? STATUS_BADGE_CLASSES[status]
    : STATUS_BADGE_FALLBACK;

  return (
    <span className={`px-1 py-0.5 rounded text-[9px] font-medium ${palette}`}>
      {status}
    </span>
  );
}

/** How many of the most recent runs are kept for display. */
const MAX_VISIBLE_RUNS = 5;

/** Delay between backend polls while a run is active, in milliseconds. */
const POLL_INTERVAL_MS = 5000;

/**
 * Outcome of asking the backend which run is active.
 *
 * `ok: false` means the request itself failed, so nothing is known. This is
 * deliberately distinct from `ok: true` with `runId: null`, which means the
 * backend answered and no run is active.
 */
type ActiveRunLookup = { ok: true; runId: string | null } | { ok: false };

/** Renders a fraction as a percentage with one decimal, e.g. 0.912 -> "91.2%". */
function formatPercent(fraction: number): string {
  return `${(fraction * 100).toFixed(1)}%`;
}

/** Joins the numeric entries of a metrics summary into "name: 12.3%, ..." text. */
function summarizeMetrics(metrics: TrainingRun['metrics_summary']): string {
  if (!metrics) return '';
  return Object.entries(metrics)
    .filter((entry): entry is [string, number] => typeof entry[1] === 'number')
    .map(([name, value]) => `${name}: ${formatPercent(value)}`)
    .join(', ');
}

/** Display name for a model type: the selector label when known, else the key with underscores as spaces. */
function modelLabel(modelType: string): string {
  const known = MODEL_TYPES.find((option) => option.value === modelType);
  return known ? known.label : modelType.replace(/_/g, ' ');
}

/**
 * Collapsible panel for starting model training and reviewing recent runs.
 *
 * While expanded it shows the dataset status, a model selector, a start
 * button and the latest runs. While a run is active it polls the backend
 * every `POLL_INTERVAL_MS` and reports the outcome once the backend says
 * nothing is active any more.
 *
 * Endpoints used: `/api/training/dataset-info`, `/api/training/runs`,
 * `/api/training/active`, `/api/training/start` and
 * `DELETE /api/training/runs/:id`.
 */
export default function TrainingPanel() {
  const [expanded, setExpanded] = useState(false);
  const [modelType, setModelType] = useState('random_forest');
  const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
  const [runs, setRuns] = useState<TrainingRun[]>([]);
  const [isTraining, setIsTraining] = useState(false);
  const [activeRunId, setActiveRunId] = useState<string | null>(null);
  const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
  const [isLoadingDataset, setIsLoadingDataset] = useState(false);
  const [isLoadingRuns, setIsLoadingRuns] = useState(false);
  const [deletingRunIds, setDeletingRunIds] = useState<Set<string>>(new Set());
  // Handle of the currently scheduled poll timer, if any.
  const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  /** Loads dataset status; on failure the previous value is kept. */
  const fetchDatasetInfo = useCallback(async () => {
    setIsLoadingDataset(true);
    try {
      const res = await fetch('/api/training/dataset-info');
      if (res.ok) setDatasetInfo(await res.json());
    } catch {
      // Best effort: the panel shows "Dataset info unavailable" when nothing was ever loaded.
    } finally {
      setIsLoadingDataset(false);
    }
  }, []);

  /**
   * Loads the run list, stores the newest `MAX_VISIBLE_RUNS` for display and
   * returns the full list. Returns `[]` (leaving the displayed list untouched)
   * when the request fails or the payload is not a list.
   */
  const fetchRuns = useCallback(async (): Promise<TrainingRun[]> => {
    try {
      const res = await fetch('/api/training/runs');
      if (!res.ok) return [];
      const data = await res.json();
      // The endpoint may answer with `{ runs: [...] }` or with a bare array.
      const candidate: unknown = data?.runs || data;
      if (!Array.isArray(candidate)) return [];
      const runList: TrainingRun[] = candidate;
      setRuns(runList.slice(0, MAX_VISIBLE_RUNS));
      return runList;
    } catch {
      return [];
    }
  }, []);

  /** Asks the backend for the authoritative active run. */
  const fetchActiveRun = useCallback(async (): Promise<ActiveRunLookup> => {
    try {
      const res = await fetch('/api/training/active');
      if (!res.ok) return { ok: false };
      const data = await res.json();
      return { ok: true, runId: data.active ? (data.run_id ?? null) : null };
    } catch {
      return { ok: false };
    }
  }, []);

  // Load data when the panel expands and sync the training state with the backend.
  useEffect(() => {
    if (!expanded) return;

    void fetchDatasetInfo();
    setIsLoadingRuns(true);

    void Promise.all([fetchRuns(), fetchActiveRun()])
      .then(([, lookup]) => {
        // A failed lookup tells us nothing, so keep whatever state we have.
        if (!lookup.ok) return;

        if (lookup.runId) {
          setActiveRunId(lookup.runId);
          setIsTraining(true);
        } else {
          // Ensure we are not stuck in training state.
          setIsTraining(false);
          setActiveRunId(null);
        }
      })
      .finally(() => setIsLoadingRuns(false));
  }, [expanded, fetchDatasetInfo, fetchRuns, fetchActiveRun]);

  // Poll while training is active.
  useEffect(() => {
    if (!isTraining || !activeRunId) return undefined;

    let cancelled = false; // set by cleanup so a late response cannot touch state
    let inFlight = false; // prevents overlapping polls when a request is slow

    const poll = async () => {
      if (inFlight) return;
      inFlight = true;
      try {
        const [updatedRuns, lookup] = await Promise.all([fetchRuns(), fetchActiveRun()]);

        // Keep polling when this effect was torn down, when the lookup failed
        // (a transient error must not end the training state), or when the
        // backend still reports an active run.
        if (cancelled || !lookup.ok || lookup.runId) return;

        // The backend says nothing is active: training is done.
        const finishedRun = updatedRuns.find((r) => r.run_id === activeRunId);
        setIsTraining(false);
        setActiveRunId(null);

        if (finishedRun?.status === 'completed') {
          const metricStr = summarizeMetrics(finishedRun.metrics_summary);
          setStatusMessage({
            type: 'success',
            text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
          });
        } else if (finishedRun?.status === 'failed') {
          setStatusMessage({
            type: 'error',
            text: finishedRun.error || 'Training failed',
          });
        }
      } finally {
        inFlight = false;
      }
    };

    pollIntervalRef.current = setInterval(() => void poll(), POLL_INTERVAL_MS);

    return () => {
      cancelled = true;
      if (pollIntervalRef.current) {
        clearInterval(pollIntervalRef.current);
        pollIntervalRef.current = null;
      }
    };
  }, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);

  /**
   * Starts a training run for the selected model type. A 409 response means a
   * run is already in progress; in that case the panel attaches to that run.
   */
  const handleStartTraining = async () => {
    setStatusMessage(null);
    setIsTraining(true);
    try {
      const res = await fetch('/api/training/start', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ model_type: modelType }),
      });
      const alreadyRunning = res.status === 409;

      // Response bodies are not guaranteed to be JSON (e.g. a proxy error page).
      const data = await res.json().catch(() => ({}));

      if (!res.ok && !alreadyRunning) {
        throw new Error(data.error || data.detail || 'Failed to start training');
      }

      // Polling only runs once a run id is known, so when the response does
      // not carry one, ask the backend instead of spinning forever.
      let runId: string | null = data.run_id || null;
      if (!runId) {
        const lookup = await fetchActiveRun();
        runId = lookup.ok ? lookup.runId : null;
      }
      if (!runId) {
        throw new Error(
          alreadyRunning
            ? 'Training is already in progress, but the active run could not be determined'
            : 'Training was started, but the run could not be identified',
        );
      }

      setActiveRunId(runId);

      // Refresh the list so a newly created run shows up right away.
      if (!alreadyRunning) await fetchRuns();
    } catch (e) {
      setIsTraining(false);
      setStatusMessage({
        type: 'error',
        text: e instanceof Error ? e.message : 'Failed to start training',
      });
    }
  };

  /** Deletes a run and refreshes the list; the run's button is disabled meanwhile. */
  const handleDeleteRun = async (runId: string) => {
    setDeletingRunIds((prev) => new Set(prev).add(runId));
    try {
      const res = await fetch(`/api/training/runs/${encodeURIComponent(runId)}`, { method: 'DELETE' });
      if (!res.ok) {
        const data = await res.json().catch(() => ({}));
        setStatusMessage({ type: 'error', text: data.error || data.detail || 'Failed to delete run' });
        return;
      }
      await fetchRuns();
    } catch {
      setStatusMessage({ type: 'error', text: 'Failed to delete run' });
    } finally {
      setDeletingRunIds((prev) => {
        const next = new Set(prev);
        next.delete(runId);
        return next;
      });
    }
  };

  const canTrain = !isTraining;

  return (
    <div>
      <button
        type="button"
        onClick={() => setExpanded(!expanded)}
        aria-expanded={expanded}
        className="w-full flex items-center justify-between py-1"
      >
        <div className="flex items-center gap-1.5">
          {isTraining && (
            <div className="w-1.5 h-1.5 rounded-full bg-blue-500 animate-pulse" />
          )}
          <span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
            Training
          </span>
        </div>
        <ChevronDown
          className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
        />
      </button>

      {expanded && (
        <div className="mt-1 space-y-2">
          {/* Dataset info */}
          <div className="p-2 bg-secondary/30 rounded text-[10px] space-y-0.5">
            {isLoadingDataset ? (
              <p className="text-muted-foreground">Checking dataset...</p>
            ) : datasetInfo === null ? (
              <p className="text-muted-foreground">Dataset info unavailable</p>
            ) : datasetInfo.exists ? (
              <>
                <div className="flex justify-between">
                  <span className="text-muted-foreground">Dataset:</span>
                  <span className="text-green-600 font-mono">Ready</span>
                </div>
                {datasetInfo.row_count !== undefined && (
                  <div className="flex justify-between">
                    <span className="text-muted-foreground">Rows:</span>
                    <span className="font-mono text-foreground">{datasetInfo.row_count.toLocaleString()}</span>
                  </div>
                )}
              </>
            ) : (
              <p className="text-muted-foreground">
                No cached dataset. It will be built automatically from your annotations when training starts.
              </p>
            )}
          </div>

          {/* Model type selector */}
          <div>
            <label className="text-[10px] text-muted-foreground block mb-0.5">Model Type</label>
            <select
              value={modelType}
              onChange={(e) => setModelType(e.target.value)}
              disabled={isTraining}
              className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
            >
              {MODEL_TYPES.map((m) => (
                <option key={m.value} value={m.value}>
                  {m.label}
                </option>
              ))}
            </select>
          </div>

          {/* Start Training button */}
          <button
            type="button"
            onClick={handleStartTraining}
            disabled={!canTrain}
            className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
          >
            {isTraining ? (
              <span className="flex items-center justify-center gap-1.5">
                <span className="inline-block w-2 h-2 rounded-full border border-primary-foreground border-t-transparent animate-spin" />
                Training in progress...
              </span>
            ) : (
              'Start Training'
            )}
          </button>

          {/* Status message */}
          {statusMessage && (
            <p
              className={`text-[10px] ${
                statusMessage.type === 'success' ? 'text-green-600' : 'text-destructive'
              }`}
            >
              {statusMessage.text}
            </p>
          )}

          {/* Training run history */}
          <div>
            <p className="text-[10px] text-muted-foreground mb-1">Recent runs</p>
            {isLoadingRuns ? (
              <p className="text-[10px] text-muted-foreground">Loading...</p>
            ) : runs.length === 0 ? (
              <p className="text-[10px] text-muted-foreground">No training runs yet</p>
            ) : (
              <div className="space-y-1">
                {runs.map((run) => {
                  // Metrics are only shown for completed runs; prefer the plain
                  // metric and fall back to its `test_` variant.
                  const summary = run.status === 'completed' ? run.metrics_summary : undefined;
                  const accuracy = summary?.accuracy ?? summary?.test_accuracy;
                  const f1 = summary?.f1_macro ?? summary?.test_f1_macro;

                  return (
                    <div key={run.run_id} className="p-1.5 bg-secondary/30 rounded text-[10px] space-y-0.5">
                      <div className="flex items-center justify-between gap-1">
                        <span className="text-foreground font-medium capitalize">
                          {modelLabel(run.model_type)}
                        </span>
                        <div className="flex items-center gap-1">
                          <StatusBadge status={run.status} />
                          {run.status !== 'running' && (
                            <button
                              type="button"
                              onClick={() => handleDeleteRun(run.run_id)}
                              disabled={deletingRunIds.has(run.run_id)}
                              title="Delete run"
                              aria-label="Delete run"
                              className="text-muted-foreground hover:text-destructive transition-colors disabled:opacity-40"
                            >
                              <Trash2 className="w-2.5 h-2.5" />
                            </button>
                          )}
                        </div>
                      </div>
                      <div className="text-muted-foreground">{formatDate(run.created_at)}</div>
                      {summary && (
                        <div className="flex flex-wrap gap-x-2 text-muted-foreground">
                          {typeof accuracy === 'number' && <span>Acc: {formatPercent(accuracy)}</span>}
                          {typeof f1 === 'number' && <span>F1: {formatPercent(f1)}</span>}
                        </div>
                      )}
                      {run.status === 'failed' && run.error && (
                        <p className="text-destructive truncate" title={run.error}>
                          {run.error}
                        </p>
                      )}
                    </div>
                  );
                })}
              </div>
            )}
          </div>
        </div>
      )}
    </div>
  );
}