feat: add TalibPatternPanel, TrainingPanel, ModelSelector UI components (tasks 5-8)
- TalibPatternPanel: pattern checkboxes, detect button, results summary, clear-all and per-pattern delete - TrainingPanel: model type selector, dataset info, start training, polling, run history - ModelSelector: dropdown of completed runs, wired into PredictionPanel for model switching - page.tsx: integrate all three panels into sidebar, wire callbacks (model load, annotations refresh) - tasks.md: mark all 39 tasks complete
This commit is contained in:
parent
2a02669222
commit
12a9603fce
7 changed files with 849 additions and 23 deletions
329
src/components/TrainingPanel.tsx
Normal file
329
src/components/TrainingPanel.tsx
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
'use client';
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
|
||||
interface DatasetInfo {
|
||||
path: string;
|
||||
exists: boolean;
|
||||
size_bytes?: number;
|
||||
last_modified?: string;
|
||||
row_count?: number;
|
||||
}
|
||||
|
||||
interface TrainingRun {
|
||||
run_id: string;
|
||||
model_type: string;
|
||||
status: 'running' | 'completed' | 'failed';
|
||||
created_at: string;
|
||||
completed_at?: string;
|
||||
metrics_summary?: {
|
||||
accuracy?: number;
|
||||
f1_macro?: number;
|
||||
[key: string]: number | undefined;
|
||||
};
|
||||
error?: string;
|
||||
}
|
||||
|
||||
const MODEL_TYPES = [
|
||||
{ value: 'random_forest', label: 'Random Forest' },
|
||||
{ value: 'xgboost', label: 'XGBoost' },
|
||||
];
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
try {
|
||||
return new Date(iso).toLocaleString(undefined, {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
});
|
||||
} catch {
|
||||
return iso;
|
||||
}
|
||||
}
|
||||
|
||||
function StatusBadge({ status }: { status: string }) {
|
||||
const colors: Record<string, string> = {
|
||||
running: 'bg-blue-500/20 text-blue-500',
|
||||
completed: 'bg-green-500/20 text-green-600',
|
||||
failed: 'bg-red-500/20 text-destructive',
|
||||
};
|
||||
return (
|
||||
<span className={`px-1 py-0.5 rounded text-[9px] font-medium ${colors[status] || 'bg-secondary text-muted-foreground'}`}>
|
||||
{status}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export default function TrainingPanel() {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [modelType, setModelType] = useState('random_forest');
|
||||
const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
|
||||
const [runs, setRuns] = useState<TrainingRun[]>([]);
|
||||
const [isTraining, setIsTraining] = useState(false);
|
||||
const [activeRunId, setActiveRunId] = useState<string | null>(null);
|
||||
const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
|
||||
const [isLoadingDataset, setIsLoadingDataset] = useState(false);
|
||||
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
|
||||
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const fetchDatasetInfo = useCallback(async () => {
|
||||
setIsLoadingDataset(true);
|
||||
try {
|
||||
const res = await fetch('/api/training/dataset-info');
|
||||
if (res.ok) setDatasetInfo(await res.json());
|
||||
} catch {
|
||||
// ignore
|
||||
} finally {
|
||||
setIsLoadingDataset(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const fetchRuns = useCallback(async (): Promise<TrainingRun[]> => {
|
||||
try {
|
||||
const res = await fetch('/api/training/runs');
|
||||
if (!res.ok) return [];
|
||||
const data = await res.json();
|
||||
const runList: TrainingRun[] = data.runs || data;
|
||||
setRuns(runList.slice(0, 5));
|
||||
return runList;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Load data when panel expands
|
||||
useEffect(() => {
|
||||
if (expanded) {
|
||||
fetchDatasetInfo();
|
||||
setIsLoadingRuns(true);
|
||||
fetchRuns().finally(() => setIsLoadingRuns(false));
|
||||
}
|
||||
}, [expanded, fetchDatasetInfo, fetchRuns]);
|
||||
|
||||
// Check if there's an active training run on expand
|
||||
useEffect(() => {
|
||||
if (expanded && runs.length > 0) {
|
||||
const running = runs.find((r) => r.status === 'running');
|
||||
if (running && !activeRunId) {
|
||||
setActiveRunId(running.run_id);
|
||||
setIsTraining(true);
|
||||
}
|
||||
}
|
||||
}, [expanded, runs, activeRunId]);
|
||||
|
||||
// Poll while training is active
|
||||
useEffect(() => {
|
||||
if (isTraining && activeRunId) {
|
||||
pollIntervalRef.current = setInterval(async () => {
|
||||
const updatedRuns = await fetchRuns();
|
||||
const activeRun = updatedRuns.find((r) => r.run_id === activeRunId);
|
||||
if (activeRun && activeRun.status !== 'running') {
|
||||
setIsTraining(false);
|
||||
setActiveRunId(null);
|
||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||
|
||||
if (activeRun.status === 'completed') {
|
||||
const metrics = activeRun.metrics_summary;
|
||||
const metricStr = metrics
|
||||
? Object.entries(metrics)
|
||||
.map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`)
|
||||
.join(', ')
|
||||
: '';
|
||||
setStatusMessage({
|
||||
type: 'success',
|
||||
text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
|
||||
});
|
||||
} else {
|
||||
setStatusMessage({
|
||||
type: 'error',
|
||||
text: activeRun.error || 'Training failed',
|
||||
});
|
||||
}
|
||||
}
|
||||
}, 5000);
|
||||
} else {
|
||||
if (pollIntervalRef.current) {
|
||||
clearInterval(pollIntervalRef.current);
|
||||
pollIntervalRef.current = null;
|
||||
}
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||
};
|
||||
}, [isTraining, activeRunId, fetchRuns]);
|
||||
|
||||
const handleStartTraining = async () => {
|
||||
setStatusMessage(null);
|
||||
setIsTraining(true);
|
||||
try {
|
||||
const res = await fetch('/api/training/start', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model_type: modelType }),
|
||||
});
|
||||
|
||||
if (res.status === 409) {
|
||||
const data = await res.json();
|
||||
setActiveRunId(data.run_id || null);
|
||||
// Already training – keep isTraining true and let poll handle it
|
||||
return;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
throw new Error(data.error || data.detail || 'Failed to start training');
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
setActiveRunId(data.run_id);
|
||||
// Refresh runs list
|
||||
await fetchRuns();
|
||||
} catch (e) {
|
||||
setIsTraining(false);
|
||||
setStatusMessage({
|
||||
type: 'error',
|
||||
text: e instanceof Error ? e.message : 'Failed to start training',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const datasetMissing = datasetInfo !== null && !datasetInfo.exists;
|
||||
const canTrain = !isTraining && !datasetMissing && datasetInfo !== null;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<button
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
className="w-full flex items-center justify-between py-1"
|
||||
>
|
||||
<div className="flex items-center gap-1.5">
|
||||
{isTraining && (
|
||||
<div className="w-1.5 h-1.5 rounded-full bg-blue-500 animate-pulse" />
|
||||
)}
|
||||
<span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
|
||||
Training
|
||||
</span>
|
||||
</div>
|
||||
<ChevronDown
|
||||
className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
|
||||
{expanded && (
|
||||
<div className="mt-1 space-y-2">
|
||||
{/* Dataset info */}
|
||||
<div className="p-2 bg-secondary/30 rounded text-[10px] space-y-0.5">
|
||||
{isLoadingDataset ? (
|
||||
<p className="text-muted-foreground">Checking dataset...</p>
|
||||
) : datasetInfo === null ? (
|
||||
<p className="text-muted-foreground">Dataset info unavailable</p>
|
||||
) : datasetInfo.exists ? (
|
||||
<>
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Dataset:</span>
|
||||
<span className="text-green-600 font-mono">Ready</span>
|
||||
</div>
|
||||
{datasetInfo.row_count !== undefined && (
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Rows:</span>
|
||||
<span className="font-mono text-foreground">{datasetInfo.row_count.toLocaleString()}</span>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<p className="text-orange-500">
|
||||
No training dataset found. Export annotations first.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Model type selector */}
|
||||
<div>
|
||||
<label className="text-[10px] text-muted-foreground block mb-0.5">Model Type</label>
|
||||
<select
|
||||
value={modelType}
|
||||
onChange={(e) => setModelType(e.target.value)}
|
||||
disabled={isTraining}
|
||||
className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
|
||||
>
|
||||
{MODEL_TYPES.map((m) => (
|
||||
<option key={m.value} value={m.value}>
|
||||
{m.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* Start Training button */}
|
||||
<button
|
||||
onClick={handleStartTraining}
|
||||
disabled={!canTrain}
|
||||
className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
|
||||
>
|
||||
{isTraining ? (
|
||||
<span className="flex items-center justify-center gap-1.5">
|
||||
<span className="inline-block w-2 h-2 rounded-full border border-primary-foreground border-t-transparent animate-spin" />
|
||||
Training in progress...
|
||||
</span>
|
||||
) : (
|
||||
'Start Training'
|
||||
)}
|
||||
</button>
|
||||
|
||||
{/* Status message */}
|
||||
{statusMessage && (
|
||||
<p
|
||||
className={`text-[10px] ${
|
||||
statusMessage.type === 'success' ? 'text-green-600' : 'text-destructive'
|
||||
}`}
|
||||
>
|
||||
{statusMessage.text}
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Training run history */}
|
||||
<div>
|
||||
<p className="text-[10px] text-muted-foreground mb-1">Recent runs</p>
|
||||
{isLoadingRuns ? (
|
||||
<p className="text-[10px] text-muted-foreground">Loading...</p>
|
||||
) : runs.length === 0 ? (
|
||||
<p className="text-[10px] text-muted-foreground">No training runs yet</p>
|
||||
) : (
|
||||
<div className="space-y-1">
|
||||
{runs.map((run) => (
|
||||
<div key={run.run_id} className="p-1.5 bg-secondary/30 rounded text-[10px] space-y-0.5">
|
||||
<div className="flex items-center justify-between gap-1">
|
||||
<span className="text-foreground font-medium capitalize">
|
||||
{run.model_type.replace('_', ' ')}
|
||||
</span>
|
||||
<StatusBadge status={run.status} />
|
||||
</div>
|
||||
<div className="text-muted-foreground">{formatDate(run.created_at)}</div>
|
||||
{run.status === 'completed' && run.metrics_summary && (
|
||||
<div className="flex flex-wrap gap-x-2 text-muted-foreground">
|
||||
{run.metrics_summary.accuracy !== undefined && (
|
||||
<span>Acc: {(run.metrics_summary.accuracy * 100).toFixed(1)}%</span>
|
||||
)}
|
||||
{run.metrics_summary.f1_macro !== undefined && (
|
||||
<span>F1: {(run.metrics_summary.f1_macro * 100).toFixed(1)}%</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{run.status === 'failed' && run.error && (
|
||||
<p className="text-destructive truncate" title={run.error}>
|
||||
{run.error}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue