feat: add TalibPatternPanel, TrainingPanel, ModelSelector UI components (tasks 5-8)

- TalibPatternPanel: pattern checkboxes, detect button, results summary, clear-all and per-pattern delete
- TrainingPanel: model type selector, dataset info, start training, polling, run history
- ModelSelector: dropdown of completed runs, wired into PredictionPanel for model switching
- page.tsx: integrate all three panels into sidebar, wire callbacks (model load, annotations refresh)
- tasks.md: mark all 39 tasks complete
This commit is contained in:
Marko Djordjevic 2026-02-17 18:55:52 +01:00
parent 2a02669222
commit 12a9603fce
7 changed files with 849 additions and 23 deletions

View file

@ -29,34 +29,34 @@
## 5. TA-Lib Pattern UI Panel
- [ ] 5.1 Create `TalibPatternPanel` component with collapsible section, fetching available patterns from `/api/patterns/available` on mount
- [ ] 5.2 Add pattern checkboxes grouped by category with "Select All" / "Deselect All" toggle
- [ ] 5.3 Add "Detect Patterns" button that sends selected patterns + chart candles to `/api/patterns/detect`, shows loading state
- [ ] 5.4 Save detection results as span annotations via `POST /api/span-annotations` with `source: "talib"` and refresh the annotation list
- [ ] 5.5 Add detection results summary showing pattern counts grouped by name
- [ ] 5.6 Add "Clear All TA-Lib" bulk delete button calling `DELETE /api/span-annotations?source=talib&chartId=X`
- [ ] 5.7 Add per-pattern-type delete in results summary
- [x] 5.1 Create `TalibPatternPanel` component with collapsible section, fetching available patterns from `/api/patterns/available` on mount
- [x] 5.2 Add pattern checkboxes grouped by category with "Select All" / "Deselect All" toggle
- [x] 5.3 Add "Detect Patterns" button that sends selected patterns + chart candles to `/api/patterns/detect`, shows loading state
- [x] 5.4 Save detection results as span annotations via `POST /api/span-annotations` with `source: "talib"` and refresh the annotation list
- [x] 5.5 Add detection results summary showing pattern counts grouped by name
- [x] 5.6 Add "Clear All TA-Lib" bulk delete button calling `DELETE /api/span-annotations?source=talib&chartId=X`
- [x] 5.7 Add per-pattern-type delete in results summary
## 6. Training UI Panel
- [ ] 6.1 Create `TrainingPanel` component with collapsible section
- [ ] 6.2 Add model type dropdown (Random Forest / XGBoost) defaulting to Random Forest
- [ ] 6.3 Add dataset info display fetching from `/api/training/dataset-info`, showing warning if missing
- [ ] 6.4 Add "Start Training" button with loading state, disabled when training active or dataset missing
- [ ] 6.5 Add training status polling (5s interval) while a run is active, showing progress indicator
- [ ] 6.6 Add training completion/failure handling with success message and metrics display
- [ ] 6.7 Add training run history list fetching from `/api/training/runs`, showing 5 most recent runs with model type, status, date, metrics
- [x] 6.1 Create `TrainingPanel` component with collapsible section
- [x] 6.2 Add model type dropdown (Random Forest / XGBoost) defaulting to Random Forest
- [x] 6.3 Add dataset info display fetching from `/api/training/dataset-info`, showing warning if missing
- [x] 6.4 Add "Start Training" button with loading state, disabled when training active or dataset missing
- [x] 6.5 Add training status polling (5s interval) while a run is active, showing progress indicator
- [x] 6.6 Add training completion/failure handling with success message and metrics display
- [x] 6.7 Add training run history list fetching from `/api/training/runs`, showing 5 most recent runs with model type, status, date, metrics
## 7. Model Selector Integration
- [ ] 7.1 Create `ModelSelector` dropdown component fetching completed training runs from `/api/training/runs`
- [ ] 7.2 Integrate `ModelSelector` into `PredictionPanel` above action buttons, showing current model as active
- [ ] 7.3 Wire model switch: on selection call `POST /api/model/load`, clear prediction cache, refresh model info
- [ ] 7.4 Handle model load errors: show error toast, keep previous model active
- [x] 7.1 Create `ModelSelector` dropdown component fetching completed training runs from `/api/training/runs`
- [x] 7.2 Integrate `ModelSelector` into `PredictionPanel` above action buttons, showing current model as active
- [x] 7.3 Wire model switch: on selection call `POST /api/model/load`, clear prediction cache, refresh model info
- [x] 7.4 Handle model load errors: show error toast, keep previous model active
## 8. Sidebar Layout Integration
- [ ] 8.1 Add `TalibPatternPanel` to the sidebar in `page.tsx` between SpanAnnotationList and PredictionPanel
- [ ] 8.2 Add `TrainingPanel` to the sidebar between TalibPatternPanel and PredictionPanel
- [ ] 8.3 Make TalibPatternPanel, TrainingPanel, and PredictionPanel collapsible (default collapsed for new panels)
- [ ] 8.4 Wire all new component state and callbacks in `page.tsx`
- [x] 8.1 Add `TalibPatternPanel` to the sidebar in `page.tsx` between SpanAnnotationList and PredictionPanel
- [x] 8.2 Add `TrainingPanel` to the sidebar between TalibPatternPanel and PredictionPanel
- [x] 8.3 Make TalibPatternPanel, TrainingPanel, and PredictionPanel collapsible (default collapsed for new panels)
- [x] 8.4 Wire all new component state and callbacks in `page.tsx`

View file

@ -7,6 +7,8 @@ import CandleChart, { CandleChartHandle } from '@/components/CandleChart';
import ChartSelector from '@/components/ChartSelector';
import PredictionPanel from '@/components/PredictionPanel';
import SpanAnnotationList from '@/components/SpanAnnotationList';
import TalibPatternPanel from '@/components/TalibPatternPanel';
import TrainingPanel from '@/components/TrainingPanel';
import { ThemeToggle } from '@/components/ThemeToggle';
import type { PredictionState, PredictionSpan, ModelInfoResponse, Disagreement, DisagreementType, PredictionSummary } from '@/types/predictions';
@ -566,6 +568,13 @@ export default function Home() {
}
}, [fetchPredictions]);
// Handle model loaded via ModelSelector: refresh model info and clear prediction cache
const handleModelLoaded = useCallback(async () => {
predictionCacheRef.current = new Map();
setPredictionState((prev) => ({ ...prev, spans: [], perCandlePredictions: [], visible: false }));
await fetchModelInfo();
}, [fetchModelInfo]);
// Handle batch prediction for all candles
const handleFetchBatchPredictions = useCallback(async () => {
if (!activeChartId) return;
@ -758,6 +767,20 @@ export default function Home() {
/>
</div>
{/* TA-Lib Pattern Panel */}
<div className="px-3 py-2 border-t border-sidebar-border">
<TalibPatternPanel
activeChartId={activeChartId}
onAnnotationsChanged={() => fetchSpanAnnotations(activeChartId)}
getCandles={() => chartRef.current?.getVisibleCandles()}
/>
</div>
{/* Training Panel */}
<div className="px-3 py-2 border-t border-sidebar-border">
<TrainingPanel />
</div>
{/* Predictions */}
<div className="px-3 py-2 border-t border-sidebar-border">
<PredictionPanel
@ -771,6 +794,7 @@ export default function Home() {
isModelOnline={isModelOnline}
showOnlyDisagreements={showOnlyDisagreements}
onToggleShowOnlyDisagreements={toggleShowOnlyDisagreements}
onModelLoaded={handleModelLoaded}
/>
</div>

View file

@ -0,0 +1,138 @@
'use client';
import { useState, useEffect, useCallback } from 'react';
interface TrainingRun {
run_id: string;
model_type: string;
status: string;
created_at: string;
metrics_summary?: {
f1_macro?: number;
accuracy?: number;
[key: string]: number | undefined;
};
}
interface ModelSelectorProps {
currentModelVersion: string | null | undefined;
onModelLoaded: () => void;
onLoadError: (msg: string) => void;
onLoadStart: () => void;
disabled?: boolean;
}
function formatDate(iso: string): string {
try {
return new Date(iso).toLocaleString(undefined, {
month: 'short',
day: 'numeric',
hour: '2-digit',
minute: '2-digit',
});
} catch {
return iso;
}
}
export default function ModelSelector({
currentModelVersion,
onModelLoaded,
onLoadError,
onLoadStart,
disabled = false,
}: ModelSelectorProps) {
const [runs, setRuns] = useState<TrainingRun[]>([]);
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
const [isLoadingModel, setIsLoadingModel] = useState(false);
const [selectedRunId, setSelectedRunId] = useState<string>('');
const fetchRuns = useCallback(async () => {
setIsLoadingRuns(true);
try {
const res = await fetch('/api/training/runs');
if (!res.ok) return;
const data = await res.json();
const runList: TrainingRun[] = (data.runs || data).filter(
(r: TrainingRun) => r.status === 'completed'
);
setRuns(runList);
} catch {
// ignore
} finally {
setIsLoadingRuns(false);
}
}, []);
useEffect(() => {
fetchRuns();
}, [fetchRuns]);
const handleChange = async (runId: string) => {
if (!runId || isLoadingModel) return;
setSelectedRunId(runId);
setIsLoadingModel(true);
onLoadStart();
try {
const res = await fetch('/api/model/load', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ run_id: runId }),
});
if (!res.ok) {
const data = await res.json();
throw new Error(data.error || data.detail || 'Failed to load model');
}
onModelLoaded();
} catch (e) {
setSelectedRunId(''); // revert selection
onLoadError(e instanceof Error ? e.message : 'Failed to load model');
} finally {
setIsLoadingModel(false);
}
};
if (isLoadingRuns) {
return <p className="text-[10px] text-muted-foreground">Loading models...</p>;
}
if (runs.length === 0) {
return (
<p className="text-[10px] text-muted-foreground italic">No trained models available</p>
);
}
return (
<div>
<label className="text-[10px] text-muted-foreground block mb-0.5">Switch Model</label>
<select
value={selectedRunId}
onChange={(e) => handleChange(e.target.value)}
disabled={disabled || isLoadingModel}
className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
>
<option value="">
{isLoadingModel ? 'Loading...' : '— select a model —'}
</option>
{runs.map((run) => {
const f1 = run.metrics_summary?.f1_macro;
const label = [
run.model_type.replace('_', ' '),
formatDate(run.created_at),
f1 !== undefined ? `F1:${(f1 * 100).toFixed(0)}%` : null,
]
.filter(Boolean)
.join(' · ');
return (
<option key={run.run_id} value={run.run_id}>
{label}
</option>
);
})}
</select>
</div>
);
}

View file

@ -3,6 +3,7 @@
import { useState } from 'react';
import { ChevronDown } from 'lucide-react';
import type { PredictionState, ModelInfoResponse, PredictionSummary } from '@/types/predictions';
import ModelSelector from '@/components/ModelSelector';
interface PredictionPanelProps {
predictionState: PredictionState;
@ -15,6 +16,9 @@ interface PredictionPanelProps {
isModelOnline: boolean;
showOnlyDisagreements?: boolean;
onToggleShowOnlyDisagreements?: () => void;
onModelLoaded?: () => void;
onModelLoadError?: (msg: string) => void;
onModelLoadStart?: () => void;
}
export default function PredictionPanel({
@ -28,8 +32,12 @@ export default function PredictionPanel({
isModelOnline,
showOnlyDisagreements = false,
onToggleShowOnlyDisagreements,
onModelLoaded,
onModelLoadError,
onModelLoadStart,
}: PredictionPanelProps) {
const [expanded, setExpanded] = useState(true);
const [modelLoadError, setModelLoadError] = useState<string | null>(null);
const {
visible,
@ -86,6 +94,28 @@ export default function PredictionPanel({
</div>
)}
{/* Model Selector */}
{onModelLoaded && (
<div>
<ModelSelector
currentModelVersion={modelInfo?.model_version}
onModelLoaded={() => {
setModelLoadError(null);
onModelLoaded();
}}
onLoadError={(msg) => {
setModelLoadError(msg);
onModelLoadError?.(msg);
}}
onLoadStart={() => onModelLoadStart?.()}
disabled={isLoading}
/>
{modelLoadError && (
<p className="text-[10px] text-destructive mt-0.5">{modelLoadError}</p>
)}
</div>
)}
{/* Action Buttons */}
<div className="flex gap-1">
<button

View file

@ -0,0 +1,305 @@
'use client';
import { useState, useEffect, useCallback } from 'react';
import { ChevronDown, Trash2 } from 'lucide-react';
interface PatternInfo {
function_name: string;
display_name: string;
}
interface DetectionResult {
label: string;
count: number;
}
interface TalibPatternPanelProps {
activeChartId: number | null;
onAnnotationsChanged: () => void;
getCandles: () => any[] | undefined;
}
export default function TalibPatternPanel({
activeChartId,
onAnnotationsChanged,
getCandles,
}: TalibPatternPanelProps) {
const [expanded, setExpanded] = useState(false);
const [patterns, setPatterns] = useState<PatternInfo[]>([]);
const [selectedPatterns, setSelectedPatterns] = useState<Set<string>>(new Set());
const [isLoading, setIsLoading] = useState(false);
const [isFetchingPatterns, setIsFetchingPatterns] = useState(false);
const [error, setError] = useState<string | null>(null);
const [results, setResults] = useState<DetectionResult[]>([]);
const [deletingLabel, setDeletingLabel] = useState<string | null>(null);
const [isDeletingAll, setIsDeletingAll] = useState(false);
// Fetch available patterns when panel expands
useEffect(() => {
if (expanded && patterns.length === 0) {
setIsFetchingPatterns(true);
fetch('/api/patterns/available')
.then((r) => r.json())
.then((data) => {
const list: PatternInfo[] = data.patterns || data;
setPatterns(list);
})
.catch(() => setError('Failed to load patterns'))
.finally(() => setIsFetchingPatterns(false));
}
}, [expanded, patterns.length]);
const handleSelectAll = () => {
setSelectedPatterns(new Set(patterns.map((p) => p.function_name)));
};
const handleDeselectAll = () => {
setSelectedPatterns(new Set());
};
const handleTogglePattern = (fn: string) => {
setSelectedPatterns((prev) => {
const next = new Set(prev);
if (next.has(fn)) next.delete(fn);
else next.add(fn);
return next;
});
};
const handleDetect = useCallback(async () => {
if (!activeChartId) {
setError('Load a chart first');
return;
}
const candles = getCandles();
if (!candles || candles.length === 0) {
setError('No candle data available');
return;
}
if (selectedPatterns.size === 0) return;
setIsLoading(true);
setError(null);
setResults([]);
try {
// Detect patterns
const detectRes = await fetch('/api/patterns/detect', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
candles,
patterns: Array.from(selectedPatterns),
}),
});
if (!detectRes.ok) {
const err = await detectRes.json();
throw new Error(err.detail || err.error || 'Detection failed');
}
const detectData = await detectRes.json();
const annotations: any[] = detectData.annotations || [];
if (annotations.length === 0) {
setResults([]);
return;
}
// Save each annotation via POST /api/span-annotations
await Promise.all(
annotations.map((ann: any) =>
fetch('/api/span-annotations', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
chart_id: activeChartId,
start_time: ann.start_time,
end_time: ann.end_time,
label: ann.label,
confidence: ann.confidence ?? null,
source: 'talib',
}),
})
)
);
// Build results summary
const counts: Record<string, number> = {};
for (const ann of annotations) {
counts[ann.label] = (counts[ann.label] || 0) + 1;
}
setResults(
Object.entries(counts).map(([label, count]) => ({ label, count }))
);
onAnnotationsChanged();
} catch (e) {
setError(e instanceof Error ? e.message : 'Detection failed');
} finally {
setIsLoading(false);
}
}, [activeChartId, getCandles, selectedPatterns, onAnnotationsChanged]);
const handleClearAll = async () => {
if (!activeChartId) return;
setIsDeletingAll(true);
try {
await fetch(
`/api/span-annotations?source=talib&chartId=${activeChartId}`,
{ method: 'DELETE' }
);
setResults([]);
onAnnotationsChanged();
} catch {
setError('Failed to clear TA-Lib annotations');
} finally {
setIsDeletingAll(false);
}
};
const handleDeleteByLabel = async (label: string) => {
if (!activeChartId) return;
setDeletingLabel(label);
try {
await fetch(
`/api/span-annotations?source=talib&label=${encodeURIComponent(label)}&chartId=${activeChartId}`,
{ method: 'DELETE' }
);
setResults((prev) => prev.filter((r) => r.label !== label));
onAnnotationsChanged();
} catch {
setError('Failed to delete pattern annotations');
} finally {
setDeletingLabel(null);
}
};
const totalDetected = results.reduce((s, r) => s + r.count, 0);
return (
<div>
<button
onClick={() => setExpanded(!expanded)}
className="w-full flex items-center justify-between py-1"
>
<span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
TA-Lib Patterns
</span>
<ChevronDown
className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
/>
</button>
{expanded && (
<div className="mt-1 space-y-2">
{isFetchingPatterns && (
<p className="text-[10px] text-muted-foreground">Loading patterns...</p>
)}
{!isFetchingPatterns && patterns.length > 0 && (
<>
{/* Select All / Deselect All */}
<div className="flex gap-1">
<button
onClick={handleSelectAll}
className="flex-1 px-1.5 py-0.5 text-[10px] bg-secondary text-secondary-foreground rounded hover:opacity-80"
>
Select All
</button>
<button
onClick={handleDeselectAll}
className="flex-1 px-1.5 py-0.5 text-[10px] bg-secondary text-secondary-foreground rounded hover:opacity-80"
>
Deselect All
</button>
</div>
{/* Pattern checkboxes */}
<div className="max-h-40 overflow-y-auto scrollbar-thin space-y-0.5 pr-1">
{patterns.map((p) => (
<label
key={p.function_name}
className="flex items-center gap-1.5 p-0.5 rounded hover:bg-secondary/50 cursor-pointer"
>
<input
type="checkbox"
checked={selectedPatterns.has(p.function_name)}
onChange={() => handleTogglePattern(p.function_name)}
className="w-3 h-3"
/>
<span className="text-[10px] text-foreground truncate">{p.display_name}</span>
</label>
))}
</div>
{/* Detect button */}
<button
onClick={handleDetect}
disabled={isLoading || selectedPatterns.size === 0 || !activeChartId}
className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
>
{isLoading
? 'Detecting...'
: selectedPatterns.size === 0
? 'Detect Patterns'
: `Detect Patterns (${selectedPatterns.size} selected)`}
</button>
</>
)}
{/* Error */}
{error && (
<p className="text-[10px] text-destructive">{error}</p>
)}
{/* Results summary */}
{results.length > 0 && (
<div className="p-2 bg-secondary/30 rounded space-y-1">
<div className="flex items-center justify-between">
<span className="text-[10px] text-muted-foreground">
Found: {totalDetected} pattern{totalDetected !== 1 ? 's' : ''}
</span>
<button
onClick={handleClearAll}
disabled={isDeletingAll}
className="text-[10px] text-destructive hover:opacity-80 disabled:opacity-50"
title="Clear all TA-Lib annotations"
>
{isDeletingAll ? 'Clearing...' : 'Clear All'}
</button>
</div>
<div className="space-y-0.5">
{results.map((r) => (
<div key={r.label} className="flex items-center justify-between gap-1">
<span className="text-[10px] text-foreground truncate flex-1">{r.label}</span>
<span className="text-[10px] font-mono text-muted-foreground">{r.count}</span>
<button
onClick={() => handleDeleteByLabel(r.label)}
disabled={deletingLabel === r.label}
className="text-muted-foreground hover:text-destructive transition-colors disabled:opacity-50"
title={`Delete all ${r.label} annotations`}
>
<Trash2 className="w-2.5 h-2.5" />
</button>
</div>
))}
</div>
</div>
)}
{/* Clear All button when results are empty but TA-Lib annotations may exist */}
{results.length === 0 && !isLoading && (
<button
onClick={handleClearAll}
disabled={isDeletingAll || !activeChartId}
className="w-full px-2 py-1 text-[10px] text-destructive border border-destructive/30 rounded hover:bg-destructive/10 transition-colors disabled:opacity-50"
>
{isDeletingAll ? 'Clearing...' : 'Clear All TA-Lib'}
</button>
)}
</div>
)}
</div>
);
}

View file

@ -0,0 +1,329 @@
'use client';
import { useState, useEffect, useRef, useCallback } from 'react';
import { ChevronDown } from 'lucide-react';
interface DatasetInfo {
path: string;
exists: boolean;
size_bytes?: number;
last_modified?: string;
row_count?: number;
}
interface TrainingRun {
run_id: string;
model_type: string;
status: 'running' | 'completed' | 'failed';
created_at: string;
completed_at?: string;
metrics_summary?: {
accuracy?: number;
f1_macro?: number;
[key: string]: number | undefined;
};
error?: string;
}
const MODEL_TYPES = [
{ value: 'random_forest', label: 'Random Forest' },
{ value: 'xgboost', label: 'XGBoost' },
];
function formatDate(iso: string): string {
try {
return new Date(iso).toLocaleString(undefined, {
month: 'short',
day: 'numeric',
hour: '2-digit',
minute: '2-digit',
});
} catch {
return iso;
}
}
function StatusBadge({ status }: { status: string }) {
const colors: Record<string, string> = {
running: 'bg-blue-500/20 text-blue-500',
completed: 'bg-green-500/20 text-green-600',
failed: 'bg-red-500/20 text-destructive',
};
return (
<span className={`px-1 py-0.5 rounded text-[9px] font-medium ${colors[status] || 'bg-secondary text-muted-foreground'}`}>
{status}
</span>
);
}
export default function TrainingPanel() {
const [expanded, setExpanded] = useState(false);
const [modelType, setModelType] = useState('random_forest');
const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
const [runs, setRuns] = useState<TrainingRun[]>([]);
const [isTraining, setIsTraining] = useState(false);
const [activeRunId, setActiveRunId] = useState<string | null>(null);
const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
const [isLoadingDataset, setIsLoadingDataset] = useState(false);
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
const fetchDatasetInfo = useCallback(async () => {
setIsLoadingDataset(true);
try {
const res = await fetch('/api/training/dataset-info');
if (res.ok) setDatasetInfo(await res.json());
} catch {
// ignore
} finally {
setIsLoadingDataset(false);
}
}, []);
const fetchRuns = useCallback(async (): Promise<TrainingRun[]> => {
try {
const res = await fetch('/api/training/runs');
if (!res.ok) return [];
const data = await res.json();
const runList: TrainingRun[] = data.runs || data;
setRuns(runList.slice(0, 5));
return runList;
} catch {
return [];
}
}, []);
// Load data when panel expands
useEffect(() => {
if (expanded) {
fetchDatasetInfo();
setIsLoadingRuns(true);
fetchRuns().finally(() => setIsLoadingRuns(false));
}
}, [expanded, fetchDatasetInfo, fetchRuns]);
// Check if there's an active training run on expand
useEffect(() => {
if (expanded && runs.length > 0) {
const running = runs.find((r) => r.status === 'running');
if (running && !activeRunId) {
setActiveRunId(running.run_id);
setIsTraining(true);
}
}
}, [expanded, runs, activeRunId]);
// Poll while training is active
useEffect(() => {
if (isTraining && activeRunId) {
pollIntervalRef.current = setInterval(async () => {
const updatedRuns = await fetchRuns();
const activeRun = updatedRuns.find((r) => r.run_id === activeRunId);
if (activeRun && activeRun.status !== 'running') {
setIsTraining(false);
setActiveRunId(null);
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
if (activeRun.status === 'completed') {
const metrics = activeRun.metrics_summary;
const metricStr = metrics
? Object.entries(metrics)
.map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`)
.join(', ')
: '';
setStatusMessage({
type: 'success',
text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
});
} else {
setStatusMessage({
type: 'error',
text: activeRun.error || 'Training failed',
});
}
}
}, 5000);
} else {
if (pollIntervalRef.current) {
clearInterval(pollIntervalRef.current);
pollIntervalRef.current = null;
}
}
return () => {
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
};
}, [isTraining, activeRunId, fetchRuns]);
const handleStartTraining = async () => {
setStatusMessage(null);
setIsTraining(true);
try {
const res = await fetch('/api/training/start', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ model_type: modelType }),
});
if (res.status === 409) {
const data = await res.json();
setActiveRunId(data.run_id || null);
// Already training keep isTraining true and let poll handle it
return;
}
if (!res.ok) {
const data = await res.json();
throw new Error(data.error || data.detail || 'Failed to start training');
}
const data = await res.json();
setActiveRunId(data.run_id);
// Refresh runs list
await fetchRuns();
} catch (e) {
setIsTraining(false);
setStatusMessage({
type: 'error',
text: e instanceof Error ? e.message : 'Failed to start training',
});
}
};
const datasetMissing = datasetInfo !== null && !datasetInfo.exists;
const canTrain = !isTraining && !datasetMissing && datasetInfo !== null;
return (
<div>
<button
onClick={() => setExpanded(!expanded)}
className="w-full flex items-center justify-between py-1"
>
<div className="flex items-center gap-1.5">
{isTraining && (
<div className="w-1.5 h-1.5 rounded-full bg-blue-500 animate-pulse" />
)}
<span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
Training
</span>
</div>
<ChevronDown
className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
/>
</button>
{expanded && (
<div className="mt-1 space-y-2">
{/* Dataset info */}
<div className="p-2 bg-secondary/30 rounded text-[10px] space-y-0.5">
{isLoadingDataset ? (
<p className="text-muted-foreground">Checking dataset...</p>
) : datasetInfo === null ? (
<p className="text-muted-foreground">Dataset info unavailable</p>
) : datasetInfo.exists ? (
<>
<div className="flex justify-between">
<span className="text-muted-foreground">Dataset:</span>
<span className="text-green-600 font-mono">Ready</span>
</div>
{datasetInfo.row_count !== undefined && (
<div className="flex justify-between">
<span className="text-muted-foreground">Rows:</span>
<span className="font-mono text-foreground">{datasetInfo.row_count.toLocaleString()}</span>
</div>
)}
</>
) : (
<p className="text-orange-500">
No training dataset found. Export annotations first.
</p>
)}
</div>
{/* Model type selector */}
<div>
<label className="text-[10px] text-muted-foreground block mb-0.5">Model Type</label>
<select
value={modelType}
onChange={(e) => setModelType(e.target.value)}
disabled={isTraining}
className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
>
{MODEL_TYPES.map((m) => (
<option key={m.value} value={m.value}>
{m.label}
</option>
))}
</select>
</div>
{/* Start Training button */}
<button
onClick={handleStartTraining}
disabled={!canTrain}
className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
>
{isTraining ? (
<span className="flex items-center justify-center gap-1.5">
<span className="inline-block w-2 h-2 rounded-full border border-primary-foreground border-t-transparent animate-spin" />
Training in progress...
</span>
) : (
'Start Training'
)}
</button>
{/* Status message */}
{statusMessage && (
<p
className={`text-[10px] ${
statusMessage.type === 'success' ? 'text-green-600' : 'text-destructive'
}`}
>
{statusMessage.text}
</p>
)}
{/* Training run history */}
<div>
<p className="text-[10px] text-muted-foreground mb-1">Recent runs</p>
{isLoadingRuns ? (
<p className="text-[10px] text-muted-foreground">Loading...</p>
) : runs.length === 0 ? (
<p className="text-[10px] text-muted-foreground">No training runs yet</p>
) : (
<div className="space-y-1">
{runs.map((run) => (
<div key={run.run_id} className="p-1.5 bg-secondary/30 rounded text-[10px] space-y-0.5">
<div className="flex items-center justify-between gap-1">
<span className="text-foreground font-medium capitalize">
{run.model_type.replace('_', ' ')}
</span>
<StatusBadge status={run.status} />
</div>
<div className="text-muted-foreground">{formatDate(run.created_at)}</div>
{run.status === 'completed' && run.metrics_summary && (
<div className="flex flex-wrap gap-x-2 text-muted-foreground">
{run.metrics_summary.accuracy !== undefined && (
<span>Acc: {(run.metrics_summary.accuracy * 100).toFixed(1)}%</span>
)}
{run.metrics_summary.f1_macro !== undefined && (
<span>F1: {(run.metrics_summary.f1_macro * 100).toFixed(1)}%</span>
)}
</div>
)}
{run.status === 'failed' && run.error && (
<p className="text-destructive truncate" title={run.error}>
{run.error}
</p>
)}
</div>
))}
</div>
)}
</div>
</div>
)}
</div>
);
}

File diff suppressed because one or more lines are too long