feat: add TalibPatternPanel, TrainingPanel, ModelSelector UI components (tasks 5-8)
- TalibPatternPanel: pattern checkboxes, detect button, results summary, clear-all and per-pattern delete - TrainingPanel: model type selector, dataset info, start training, polling, run history - ModelSelector: dropdown of completed runs, wired into PredictionPanel for model switching - page.tsx: integrate all three panels into sidebar, wire callbacks (model load, annotations refresh) - tasks.md: mark all 39 tasks complete
This commit is contained in:
parent
2a02669222
commit
12a9603fce
7 changed files with 849 additions and 23 deletions
|
|
@ -29,34 +29,34 @@
|
|||
|
||||
## 5. TA-Lib Pattern UI Panel
|
||||
|
||||
- [ ] 5.1 Create `TalibPatternPanel` component with collapsible section, fetching available patterns from `/api/patterns/available` on mount
|
||||
- [ ] 5.2 Add pattern checkboxes grouped by category with "Select All" / "Deselect All" toggle
|
||||
- [ ] 5.3 Add "Detect Patterns" button that sends selected patterns + chart candles to `/api/patterns/detect`, shows loading state
|
||||
- [ ] 5.4 Save detection results as span annotations via `POST /api/span-annotations` with `source: "talib"` and refresh the annotation list
|
||||
- [ ] 5.5 Add detection results summary showing pattern counts grouped by name
|
||||
- [ ] 5.6 Add "Clear All TA-Lib" bulk delete button calling `DELETE /api/span-annotations?source=talib&chartId=X`
|
||||
- [ ] 5.7 Add per-pattern-type delete in results summary
|
||||
- [x] 5.1 Create `TalibPatternPanel` component with collapsible section, fetching available patterns from `/api/patterns/available` on mount
|
||||
- [x] 5.2 Add pattern checkboxes grouped by category with "Select All" / "Deselect All" toggle
|
||||
- [x] 5.3 Add "Detect Patterns" button that sends selected patterns + chart candles to `/api/patterns/detect`, shows loading state
|
||||
- [x] 5.4 Save detection results as span annotations via `POST /api/span-annotations` with `source: "talib"` and refresh the annotation list
|
||||
- [x] 5.5 Add detection results summary showing pattern counts grouped by name
|
||||
- [x] 5.6 Add "Clear All TA-Lib" bulk delete button calling `DELETE /api/span-annotations?source=talib&chartId=X`
|
||||
- [x] 5.7 Add per-pattern-type delete in results summary
|
||||
|
||||
## 6. Training UI Panel
|
||||
|
||||
- [ ] 6.1 Create `TrainingPanel` component with collapsible section
|
||||
- [ ] 6.2 Add model type dropdown (Random Forest / XGBoost) defaulting to Random Forest
|
||||
- [ ] 6.3 Add dataset info display fetching from `/api/training/dataset-info`, showing warning if missing
|
||||
- [ ] 6.4 Add "Start Training" button with loading state, disabled when training active or dataset missing
|
||||
- [ ] 6.5 Add training status polling (5s interval) while a run is active, showing progress indicator
|
||||
- [ ] 6.6 Add training completion/failure handling with success message and metrics display
|
||||
- [ ] 6.7 Add training run history list fetching from `/api/training/runs`, showing 5 most recent runs with model type, status, date, metrics
|
||||
- [x] 6.1 Create `TrainingPanel` component with collapsible section
|
||||
- [x] 6.2 Add model type dropdown (Random Forest / XGBoost) defaulting to Random Forest
|
||||
- [x] 6.3 Add dataset info display fetching from `/api/training/dataset-info`, showing warning if missing
|
||||
- [x] 6.4 Add "Start Training" button with loading state, disabled when training active or dataset missing
|
||||
- [x] 6.5 Add training status polling (5s interval) while a run is active, showing progress indicator
|
||||
- [x] 6.6 Add training completion/failure handling with success message and metrics display
|
||||
- [x] 6.7 Add training run history list fetching from `/api/training/runs`, showing 5 most recent runs with model type, status, date, metrics
|
||||
|
||||
## 7. Model Selector Integration
|
||||
|
||||
- [ ] 7.1 Create `ModelSelector` dropdown component fetching completed training runs from `/api/training/runs`
|
||||
- [ ] 7.2 Integrate `ModelSelector` into `PredictionPanel` above action buttons, showing current model as active
|
||||
- [ ] 7.3 Wire model switch: on selection call `POST /api/model/load`, clear prediction cache, refresh model info
|
||||
- [ ] 7.4 Handle model load errors: show error toast, keep previous model active
|
||||
- [x] 7.1 Create `ModelSelector` dropdown component fetching completed training runs from `/api/training/runs`
|
||||
- [x] 7.2 Integrate `ModelSelector` into `PredictionPanel` above action buttons, showing current model as active
|
||||
- [x] 7.3 Wire model switch: on selection call `POST /api/model/load`, clear prediction cache, refresh model info
|
||||
- [x] 7.4 Handle model load errors: show error toast, keep previous model active
|
||||
|
||||
## 8. Sidebar Layout Integration
|
||||
|
||||
- [ ] 8.1 Add `TalibPatternPanel` to the sidebar in `page.tsx` between SpanAnnotationList and PredictionPanel
|
||||
- [ ] 8.2 Add `TrainingPanel` to the sidebar between TalibPatternPanel and PredictionPanel
|
||||
- [ ] 8.3 Make TalibPatternPanel, TrainingPanel, and PredictionPanel collapsible (default collapsed for new panels)
|
||||
- [ ] 8.4 Wire all new component state and callbacks in `page.tsx`
|
||||
- [x] 8.1 Add `TalibPatternPanel` to the sidebar in `page.tsx` between SpanAnnotationList and PredictionPanel
|
||||
- [x] 8.2 Add `TrainingPanel` to the sidebar between TalibPatternPanel and PredictionPanel
|
||||
- [x] 8.3 Make TalibPatternPanel, TrainingPanel, and PredictionPanel collapsible (default collapsed for new panels)
|
||||
- [x] 8.4 Wire all new component state and callbacks in `page.tsx`
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import CandleChart, { CandleChartHandle } from '@/components/CandleChart';
|
|||
import ChartSelector from '@/components/ChartSelector';
|
||||
import PredictionPanel from '@/components/PredictionPanel';
|
||||
import SpanAnnotationList from '@/components/SpanAnnotationList';
|
||||
import TalibPatternPanel from '@/components/TalibPatternPanel';
|
||||
import TrainingPanel from '@/components/TrainingPanel';
|
||||
import { ThemeToggle } from '@/components/ThemeToggle';
|
||||
import type { PredictionState, PredictionSpan, ModelInfoResponse, Disagreement, DisagreementType, PredictionSummary } from '@/types/predictions';
|
||||
|
||||
|
|
@ -566,6 +568,13 @@ export default function Home() {
|
|||
}
|
||||
}, [fetchPredictions]);
|
||||
|
||||
// Handle model loaded via ModelSelector: refresh model info and clear prediction cache
|
||||
const handleModelLoaded = useCallback(async () => {
|
||||
predictionCacheRef.current = new Map();
|
||||
setPredictionState((prev) => ({ ...prev, spans: [], perCandlePredictions: [], visible: false }));
|
||||
await fetchModelInfo();
|
||||
}, [fetchModelInfo]);
|
||||
|
||||
// Handle batch prediction for all candles
|
||||
const handleFetchBatchPredictions = useCallback(async () => {
|
||||
if (!activeChartId) return;
|
||||
|
|
@ -758,6 +767,20 @@ export default function Home() {
|
|||
/>
|
||||
</div>
|
||||
|
||||
{/* TA-Lib Pattern Panel */}
|
||||
<div className="px-3 py-2 border-t border-sidebar-border">
|
||||
<TalibPatternPanel
|
||||
activeChartId={activeChartId}
|
||||
onAnnotationsChanged={() => fetchSpanAnnotations(activeChartId)}
|
||||
getCandles={() => chartRef.current?.getVisibleCandles()}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Training Panel */}
|
||||
<div className="px-3 py-2 border-t border-sidebar-border">
|
||||
<TrainingPanel />
|
||||
</div>
|
||||
|
||||
{/* Predictions */}
|
||||
<div className="px-3 py-2 border-t border-sidebar-border">
|
||||
<PredictionPanel
|
||||
|
|
@ -771,6 +794,7 @@ export default function Home() {
|
|||
isModelOnline={isModelOnline}
|
||||
showOnlyDisagreements={showOnlyDisagreements}
|
||||
onToggleShowOnlyDisagreements={toggleShowOnlyDisagreements}
|
||||
onModelLoaded={handleModelLoaded}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
|
|
|||
138
src/components/ModelSelector.tsx
Normal file
138
src/components/ModelSelector.tsx
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
'use client';
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
|
||||
interface TrainingRun {
|
||||
run_id: string;
|
||||
model_type: string;
|
||||
status: string;
|
||||
created_at: string;
|
||||
metrics_summary?: {
|
||||
f1_macro?: number;
|
||||
accuracy?: number;
|
||||
[key: string]: number | undefined;
|
||||
};
|
||||
}
|
||||
|
||||
interface ModelSelectorProps {
|
||||
currentModelVersion: string | null | undefined;
|
||||
onModelLoaded: () => void;
|
||||
onLoadError: (msg: string) => void;
|
||||
onLoadStart: () => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
try {
|
||||
return new Date(iso).toLocaleString(undefined, {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
});
|
||||
} catch {
|
||||
return iso;
|
||||
}
|
||||
}
|
||||
|
||||
export default function ModelSelector({
|
||||
currentModelVersion,
|
||||
onModelLoaded,
|
||||
onLoadError,
|
||||
onLoadStart,
|
||||
disabled = false,
|
||||
}: ModelSelectorProps) {
|
||||
const [runs, setRuns] = useState<TrainingRun[]>([]);
|
||||
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
|
||||
const [isLoadingModel, setIsLoadingModel] = useState(false);
|
||||
const [selectedRunId, setSelectedRunId] = useState<string>('');
|
||||
|
||||
const fetchRuns = useCallback(async () => {
|
||||
setIsLoadingRuns(true);
|
||||
try {
|
||||
const res = await fetch('/api/training/runs');
|
||||
if (!res.ok) return;
|
||||
const data = await res.json();
|
||||
const runList: TrainingRun[] = (data.runs || data).filter(
|
||||
(r: TrainingRun) => r.status === 'completed'
|
||||
);
|
||||
setRuns(runList);
|
||||
} catch {
|
||||
// ignore
|
||||
} finally {
|
||||
setIsLoadingRuns(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
fetchRuns();
|
||||
}, [fetchRuns]);
|
||||
|
||||
const handleChange = async (runId: string) => {
|
||||
if (!runId || isLoadingModel) return;
|
||||
setSelectedRunId(runId);
|
||||
setIsLoadingModel(true);
|
||||
onLoadStart();
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/model/load', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ run_id: runId }),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
throw new Error(data.error || data.detail || 'Failed to load model');
|
||||
}
|
||||
|
||||
onModelLoaded();
|
||||
} catch (e) {
|
||||
setSelectedRunId(''); // revert selection
|
||||
onLoadError(e instanceof Error ? e.message : 'Failed to load model');
|
||||
} finally {
|
||||
setIsLoadingModel(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (isLoadingRuns) {
|
||||
return <p className="text-[10px] text-muted-foreground">Loading models...</p>;
|
||||
}
|
||||
|
||||
if (runs.length === 0) {
|
||||
return (
|
||||
<p className="text-[10px] text-muted-foreground italic">No trained models available</p>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<label className="text-[10px] text-muted-foreground block mb-0.5">Switch Model</label>
|
||||
<select
|
||||
value={selectedRunId}
|
||||
onChange={(e) => handleChange(e.target.value)}
|
||||
disabled={disabled || isLoadingModel}
|
||||
className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
|
||||
>
|
||||
<option value="">
|
||||
{isLoadingModel ? 'Loading...' : '— select a model —'}
|
||||
</option>
|
||||
{runs.map((run) => {
|
||||
const f1 = run.metrics_summary?.f1_macro;
|
||||
const label = [
|
||||
run.model_type.replace('_', ' '),
|
||||
formatDate(run.created_at),
|
||||
f1 !== undefined ? `F1:${(f1 * 100).toFixed(0)}%` : null,
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(' · ');
|
||||
return (
|
||||
<option key={run.run_id} value={run.run_id}>
|
||||
{label}
|
||||
</option>
|
||||
);
|
||||
})}
|
||||
</select>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
import { useState } from 'react';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
import type { PredictionState, ModelInfoResponse, PredictionSummary } from '@/types/predictions';
|
||||
import ModelSelector from '@/components/ModelSelector';
|
||||
|
||||
interface PredictionPanelProps {
|
||||
predictionState: PredictionState;
|
||||
|
|
@ -15,6 +16,9 @@ interface PredictionPanelProps {
|
|||
isModelOnline: boolean;
|
||||
showOnlyDisagreements?: boolean;
|
||||
onToggleShowOnlyDisagreements?: () => void;
|
||||
onModelLoaded?: () => void;
|
||||
onModelLoadError?: (msg: string) => void;
|
||||
onModelLoadStart?: () => void;
|
||||
}
|
||||
|
||||
export default function PredictionPanel({
|
||||
|
|
@ -28,8 +32,12 @@ export default function PredictionPanel({
|
|||
isModelOnline,
|
||||
showOnlyDisagreements = false,
|
||||
onToggleShowOnlyDisagreements,
|
||||
onModelLoaded,
|
||||
onModelLoadError,
|
||||
onModelLoadStart,
|
||||
}: PredictionPanelProps) {
|
||||
const [expanded, setExpanded] = useState(true);
|
||||
const [modelLoadError, setModelLoadError] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
visible,
|
||||
|
|
@ -86,6 +94,28 @@ export default function PredictionPanel({
|
|||
</div>
|
||||
)}
|
||||
|
||||
{/* Model Selector */}
|
||||
{onModelLoaded && (
|
||||
<div>
|
||||
<ModelSelector
|
||||
currentModelVersion={modelInfo?.model_version}
|
||||
onModelLoaded={() => {
|
||||
setModelLoadError(null);
|
||||
onModelLoaded();
|
||||
}}
|
||||
onLoadError={(msg) => {
|
||||
setModelLoadError(msg);
|
||||
onModelLoadError?.(msg);
|
||||
}}
|
||||
onLoadStart={() => onModelLoadStart?.()}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
{modelLoadError && (
|
||||
<p className="text-[10px] text-destructive mt-0.5">{modelLoadError}</p>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Action Buttons */}
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
|
|
|
|||
305
src/components/TalibPatternPanel.tsx
Normal file
305
src/components/TalibPatternPanel.tsx
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
'use client';
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { ChevronDown, Trash2 } from 'lucide-react';
|
||||
|
||||
interface PatternInfo {
|
||||
function_name: string;
|
||||
display_name: string;
|
||||
}
|
||||
|
||||
interface DetectionResult {
|
||||
label: string;
|
||||
count: number;
|
||||
}
|
||||
|
||||
interface TalibPatternPanelProps {
|
||||
activeChartId: number | null;
|
||||
onAnnotationsChanged: () => void;
|
||||
getCandles: () => any[] | undefined;
|
||||
}
|
||||
|
||||
export default function TalibPatternPanel({
|
||||
activeChartId,
|
||||
onAnnotationsChanged,
|
||||
getCandles,
|
||||
}: TalibPatternPanelProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [patterns, setPatterns] = useState<PatternInfo[]>([]);
|
||||
const [selectedPatterns, setSelectedPatterns] = useState<Set<string>>(new Set());
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isFetchingPatterns, setIsFetchingPatterns] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [results, setResults] = useState<DetectionResult[]>([]);
|
||||
const [deletingLabel, setDeletingLabel] = useState<string | null>(null);
|
||||
const [isDeletingAll, setIsDeletingAll] = useState(false);
|
||||
|
||||
// Fetch available patterns when panel expands
|
||||
useEffect(() => {
|
||||
if (expanded && patterns.length === 0) {
|
||||
setIsFetchingPatterns(true);
|
||||
fetch('/api/patterns/available')
|
||||
.then((r) => r.json())
|
||||
.then((data) => {
|
||||
const list: PatternInfo[] = data.patterns || data;
|
||||
setPatterns(list);
|
||||
})
|
||||
.catch(() => setError('Failed to load patterns'))
|
||||
.finally(() => setIsFetchingPatterns(false));
|
||||
}
|
||||
}, [expanded, patterns.length]);
|
||||
|
||||
const handleSelectAll = () => {
|
||||
setSelectedPatterns(new Set(patterns.map((p) => p.function_name)));
|
||||
};
|
||||
|
||||
const handleDeselectAll = () => {
|
||||
setSelectedPatterns(new Set());
|
||||
};
|
||||
|
||||
const handleTogglePattern = (fn: string) => {
|
||||
setSelectedPatterns((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(fn)) next.delete(fn);
|
||||
else next.add(fn);
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const handleDetect = useCallback(async () => {
|
||||
if (!activeChartId) {
|
||||
setError('Load a chart first');
|
||||
return;
|
||||
}
|
||||
const candles = getCandles();
|
||||
if (!candles || candles.length === 0) {
|
||||
setError('No candle data available');
|
||||
return;
|
||||
}
|
||||
if (selectedPatterns.size === 0) return;
|
||||
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
setResults([]);
|
||||
|
||||
try {
|
||||
// Detect patterns
|
||||
const detectRes = await fetch('/api/patterns/detect', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
candles,
|
||||
patterns: Array.from(selectedPatterns),
|
||||
}),
|
||||
});
|
||||
|
||||
if (!detectRes.ok) {
|
||||
const err = await detectRes.json();
|
||||
throw new Error(err.detail || err.error || 'Detection failed');
|
||||
}
|
||||
|
||||
const detectData = await detectRes.json();
|
||||
const annotations: any[] = detectData.annotations || [];
|
||||
|
||||
if (annotations.length === 0) {
|
||||
setResults([]);
|
||||
return;
|
||||
}
|
||||
|
||||
// Save each annotation via POST /api/span-annotations
|
||||
await Promise.all(
|
||||
annotations.map((ann: any) =>
|
||||
fetch('/api/span-annotations', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
chart_id: activeChartId,
|
||||
start_time: ann.start_time,
|
||||
end_time: ann.end_time,
|
||||
label: ann.label,
|
||||
confidence: ann.confidence ?? null,
|
||||
source: 'talib',
|
||||
}),
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
// Build results summary
|
||||
const counts: Record<string, number> = {};
|
||||
for (const ann of annotations) {
|
||||
counts[ann.label] = (counts[ann.label] || 0) + 1;
|
||||
}
|
||||
setResults(
|
||||
Object.entries(counts).map(([label, count]) => ({ label, count }))
|
||||
);
|
||||
|
||||
onAnnotationsChanged();
|
||||
} catch (e) {
|
||||
setError(e instanceof Error ? e.message : 'Detection failed');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [activeChartId, getCandles, selectedPatterns, onAnnotationsChanged]);
|
||||
|
||||
const handleClearAll = async () => {
|
||||
if (!activeChartId) return;
|
||||
setIsDeletingAll(true);
|
||||
try {
|
||||
await fetch(
|
||||
`/api/span-annotations?source=talib&chartId=${activeChartId}`,
|
||||
{ method: 'DELETE' }
|
||||
);
|
||||
setResults([]);
|
||||
onAnnotationsChanged();
|
||||
} catch {
|
||||
setError('Failed to clear TA-Lib annotations');
|
||||
} finally {
|
||||
setIsDeletingAll(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeleteByLabel = async (label: string) => {
|
||||
if (!activeChartId) return;
|
||||
setDeletingLabel(label);
|
||||
try {
|
||||
await fetch(
|
||||
`/api/span-annotations?source=talib&label=${encodeURIComponent(label)}&chartId=${activeChartId}`,
|
||||
{ method: 'DELETE' }
|
||||
);
|
||||
setResults((prev) => prev.filter((r) => r.label !== label));
|
||||
onAnnotationsChanged();
|
||||
} catch {
|
||||
setError('Failed to delete pattern annotations');
|
||||
} finally {
|
||||
setDeletingLabel(null);
|
||||
}
|
||||
};
|
||||
|
||||
const totalDetected = results.reduce((s, r) => s + r.count, 0);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<button
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
className="w-full flex items-center justify-between py-1"
|
||||
>
|
||||
<span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
|
||||
TA-Lib Patterns
|
||||
</span>
|
||||
<ChevronDown
|
||||
className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
|
||||
{expanded && (
|
||||
<div className="mt-1 space-y-2">
|
||||
{isFetchingPatterns && (
|
||||
<p className="text-[10px] text-muted-foreground">Loading patterns...</p>
|
||||
)}
|
||||
|
||||
{!isFetchingPatterns && patterns.length > 0 && (
|
||||
<>
|
||||
{/* Select All / Deselect All */}
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
onClick={handleSelectAll}
|
||||
className="flex-1 px-1.5 py-0.5 text-[10px] bg-secondary text-secondary-foreground rounded hover:opacity-80"
|
||||
>
|
||||
Select All
|
||||
</button>
|
||||
<button
|
||||
onClick={handleDeselectAll}
|
||||
className="flex-1 px-1.5 py-0.5 text-[10px] bg-secondary text-secondary-foreground rounded hover:opacity-80"
|
||||
>
|
||||
Deselect All
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Pattern checkboxes */}
|
||||
<div className="max-h-40 overflow-y-auto scrollbar-thin space-y-0.5 pr-1">
|
||||
{patterns.map((p) => (
|
||||
<label
|
||||
key={p.function_name}
|
||||
className="flex items-center gap-1.5 p-0.5 rounded hover:bg-secondary/50 cursor-pointer"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={selectedPatterns.has(p.function_name)}
|
||||
onChange={() => handleTogglePattern(p.function_name)}
|
||||
className="w-3 h-3"
|
||||
/>
|
||||
<span className="text-[10px] text-foreground truncate">{p.display_name}</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Detect button */}
|
||||
<button
|
||||
onClick={handleDetect}
|
||||
disabled={isLoading || selectedPatterns.size === 0 || !activeChartId}
|
||||
className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
|
||||
>
|
||||
{isLoading
|
||||
? 'Detecting...'
|
||||
: selectedPatterns.size === 0
|
||||
? 'Detect Patterns'
|
||||
: `Detect Patterns (${selectedPatterns.size} selected)`}
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Error */}
|
||||
{error && (
|
||||
<p className="text-[10px] text-destructive">{error}</p>
|
||||
)}
|
||||
|
||||
{/* Results summary */}
|
||||
{results.length > 0 && (
|
||||
<div className="p-2 bg-secondary/30 rounded space-y-1">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[10px] text-muted-foreground">
|
||||
Found: {totalDetected} pattern{totalDetected !== 1 ? 's' : ''}
|
||||
</span>
|
||||
<button
|
||||
onClick={handleClearAll}
|
||||
disabled={isDeletingAll}
|
||||
className="text-[10px] text-destructive hover:opacity-80 disabled:opacity-50"
|
||||
title="Clear all TA-Lib annotations"
|
||||
>
|
||||
{isDeletingAll ? 'Clearing...' : 'Clear All'}
|
||||
</button>
|
||||
</div>
|
||||
<div className="space-y-0.5">
|
||||
{results.map((r) => (
|
||||
<div key={r.label} className="flex items-center justify-between gap-1">
|
||||
<span className="text-[10px] text-foreground truncate flex-1">{r.label}</span>
|
||||
<span className="text-[10px] font-mono text-muted-foreground">{r.count}</span>
|
||||
<button
|
||||
onClick={() => handleDeleteByLabel(r.label)}
|
||||
disabled={deletingLabel === r.label}
|
||||
className="text-muted-foreground hover:text-destructive transition-colors disabled:opacity-50"
|
||||
title={`Delete all ${r.label} annotations`}
|
||||
>
|
||||
<Trash2 className="w-2.5 h-2.5" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Clear All button when results are empty but TA-Lib annotations may exist */}
|
||||
{results.length === 0 && !isLoading && (
|
||||
<button
|
||||
onClick={handleClearAll}
|
||||
disabled={isDeletingAll || !activeChartId}
|
||||
className="w-full px-2 py-1 text-[10px] text-destructive border border-destructive/30 rounded hover:bg-destructive/10 transition-colors disabled:opacity-50"
|
||||
>
|
||||
{isDeletingAll ? 'Clearing...' : 'Clear All TA-Lib'}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
329
src/components/TrainingPanel.tsx
Normal file
329
src/components/TrainingPanel.tsx
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
'use client';
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
|
||||
interface DatasetInfo {
|
||||
path: string;
|
||||
exists: boolean;
|
||||
size_bytes?: number;
|
||||
last_modified?: string;
|
||||
row_count?: number;
|
||||
}
|
||||
|
||||
interface TrainingRun {
|
||||
run_id: string;
|
||||
model_type: string;
|
||||
status: 'running' | 'completed' | 'failed';
|
||||
created_at: string;
|
||||
completed_at?: string;
|
||||
metrics_summary?: {
|
||||
accuracy?: number;
|
||||
f1_macro?: number;
|
||||
[key: string]: number | undefined;
|
||||
};
|
||||
error?: string;
|
||||
}
|
||||
|
||||
const MODEL_TYPES = [
|
||||
{ value: 'random_forest', label: 'Random Forest' },
|
||||
{ value: 'xgboost', label: 'XGBoost' },
|
||||
];
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
try {
|
||||
return new Date(iso).toLocaleString(undefined, {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
});
|
||||
} catch {
|
||||
return iso;
|
||||
}
|
||||
}
|
||||
|
||||
function StatusBadge({ status }: { status: string }) {
|
||||
const colors: Record<string, string> = {
|
||||
running: 'bg-blue-500/20 text-blue-500',
|
||||
completed: 'bg-green-500/20 text-green-600',
|
||||
failed: 'bg-red-500/20 text-destructive',
|
||||
};
|
||||
return (
|
||||
<span className={`px-1 py-0.5 rounded text-[9px] font-medium ${colors[status] || 'bg-secondary text-muted-foreground'}`}>
|
||||
{status}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export default function TrainingPanel() {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [modelType, setModelType] = useState('random_forest');
|
||||
const [datasetInfo, setDatasetInfo] = useState<DatasetInfo | null>(null);
|
||||
const [runs, setRuns] = useState<TrainingRun[]>([]);
|
||||
const [isTraining, setIsTraining] = useState(false);
|
||||
const [activeRunId, setActiveRunId] = useState<string | null>(null);
|
||||
const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
|
||||
const [isLoadingDataset, setIsLoadingDataset] = useState(false);
|
||||
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
|
||||
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const fetchDatasetInfo = useCallback(async () => {
|
||||
setIsLoadingDataset(true);
|
||||
try {
|
||||
const res = await fetch('/api/training/dataset-info');
|
||||
if (res.ok) setDatasetInfo(await res.json());
|
||||
} catch {
|
||||
// ignore
|
||||
} finally {
|
||||
setIsLoadingDataset(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const fetchRuns = useCallback(async (): Promise<TrainingRun[]> => {
|
||||
try {
|
||||
const res = await fetch('/api/training/runs');
|
||||
if (!res.ok) return [];
|
||||
const data = await res.json();
|
||||
const runList: TrainingRun[] = data.runs || data;
|
||||
setRuns(runList.slice(0, 5));
|
||||
return runList;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Load data when panel expands
|
||||
useEffect(() => {
|
||||
if (expanded) {
|
||||
fetchDatasetInfo();
|
||||
setIsLoadingRuns(true);
|
||||
fetchRuns().finally(() => setIsLoadingRuns(false));
|
||||
}
|
||||
}, [expanded, fetchDatasetInfo, fetchRuns]);
|
||||
|
||||
// Check if there's an active training run on expand
|
||||
useEffect(() => {
|
||||
if (expanded && runs.length > 0) {
|
||||
const running = runs.find((r) => r.status === 'running');
|
||||
if (running && !activeRunId) {
|
||||
setActiveRunId(running.run_id);
|
||||
setIsTraining(true);
|
||||
}
|
||||
}
|
||||
}, [expanded, runs, activeRunId]);
|
||||
|
||||
// Poll while training is active
|
||||
useEffect(() => {
|
||||
if (isTraining && activeRunId) {
|
||||
pollIntervalRef.current = setInterval(async () => {
|
||||
const updatedRuns = await fetchRuns();
|
||||
const activeRun = updatedRuns.find((r) => r.run_id === activeRunId);
|
||||
if (activeRun && activeRun.status !== 'running') {
|
||||
setIsTraining(false);
|
||||
setActiveRunId(null);
|
||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||
|
||||
if (activeRun.status === 'completed') {
|
||||
const metrics = activeRun.metrics_summary;
|
||||
const metricStr = metrics
|
||||
? Object.entries(metrics)
|
||||
.map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`)
|
||||
.join(', ')
|
||||
: '';
|
||||
setStatusMessage({
|
||||
type: 'success',
|
||||
text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
|
||||
});
|
||||
} else {
|
||||
setStatusMessage({
|
||||
type: 'error',
|
||||
text: activeRun.error || 'Training failed',
|
||||
});
|
||||
}
|
||||
}
|
||||
}, 5000);
|
||||
} else {
|
||||
if (pollIntervalRef.current) {
|
||||
clearInterval(pollIntervalRef.current);
|
||||
pollIntervalRef.current = null;
|
||||
}
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||
};
|
||||
}, [isTraining, activeRunId, fetchRuns]);
|
||||
|
||||
const handleStartTraining = async () => {
|
||||
setStatusMessage(null);
|
||||
setIsTraining(true);
|
||||
try {
|
||||
const res = await fetch('/api/training/start', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model_type: modelType }),
|
||||
});
|
||||
|
||||
if (res.status === 409) {
|
||||
const data = await res.json();
|
||||
setActiveRunId(data.run_id || null);
|
||||
// Already training – keep isTraining true and let poll handle it
|
||||
return;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
throw new Error(data.error || data.detail || 'Failed to start training');
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
setActiveRunId(data.run_id);
|
||||
// Refresh runs list
|
||||
await fetchRuns();
|
||||
} catch (e) {
|
||||
setIsTraining(false);
|
||||
setStatusMessage({
|
||||
type: 'error',
|
||||
text: e instanceof Error ? e.message : 'Failed to start training',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const datasetMissing = datasetInfo !== null && !datasetInfo.exists;
|
||||
const canTrain = !isTraining && !datasetMissing && datasetInfo !== null;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<button
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
className="w-full flex items-center justify-between py-1"
|
||||
>
|
||||
<div className="flex items-center gap-1.5">
|
||||
{isTraining && (
|
||||
<div className="w-1.5 h-1.5 rounded-full bg-blue-500 animate-pulse" />
|
||||
)}
|
||||
<span className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider">
|
||||
Training
|
||||
</span>
|
||||
</div>
|
||||
<ChevronDown
|
||||
className={`w-3 h-3 text-muted-foreground transition-transform ${expanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
|
||||
{expanded && (
|
||||
<div className="mt-1 space-y-2">
|
||||
{/* Dataset info */}
|
||||
<div className="p-2 bg-secondary/30 rounded text-[10px] space-y-0.5">
|
||||
{isLoadingDataset ? (
|
||||
<p className="text-muted-foreground">Checking dataset...</p>
|
||||
) : datasetInfo === null ? (
|
||||
<p className="text-muted-foreground">Dataset info unavailable</p>
|
||||
) : datasetInfo.exists ? (
|
||||
<>
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Dataset:</span>
|
||||
<span className="text-green-600 font-mono">Ready</span>
|
||||
</div>
|
||||
{datasetInfo.row_count !== undefined && (
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Rows:</span>
|
||||
<span className="font-mono text-foreground">{datasetInfo.row_count.toLocaleString()}</span>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<p className="text-orange-500">
|
||||
No training dataset found. Export annotations first.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Model type selector */}
|
||||
<div>
|
||||
<label className="text-[10px] text-muted-foreground block mb-0.5">Model Type</label>
|
||||
<select
|
||||
value={modelType}
|
||||
onChange={(e) => setModelType(e.target.value)}
|
||||
disabled={isTraining}
|
||||
className="w-full text-[10px] bg-secondary text-foreground rounded px-1.5 py-1 border-0 outline-none disabled:opacity-50"
|
||||
>
|
||||
{MODEL_TYPES.map((m) => (
|
||||
<option key={m.value} value={m.value}>
|
||||
{m.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* Start Training button */}
|
||||
<button
|
||||
onClick={handleStartTraining}
|
||||
disabled={!canTrain}
|
||||
className="w-full px-2 py-1.5 text-[10px] bg-primary text-primary-foreground rounded hover:opacity-90 transition-opacity disabled:opacity-50"
|
||||
>
|
||||
{isTraining ? (
|
||||
<span className="flex items-center justify-center gap-1.5">
|
||||
<span className="inline-block w-2 h-2 rounded-full border border-primary-foreground border-t-transparent animate-spin" />
|
||||
Training in progress...
|
||||
</span>
|
||||
) : (
|
||||
'Start Training'
|
||||
)}
|
||||
</button>
|
||||
|
||||
{/* Status message */}
|
||||
{statusMessage && (
|
||||
<p
|
||||
className={`text-[10px] ${
|
||||
statusMessage.type === 'success' ? 'text-green-600' : 'text-destructive'
|
||||
}`}
|
||||
>
|
||||
{statusMessage.text}
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Training run history */}
|
||||
<div>
|
||||
<p className="text-[10px] text-muted-foreground mb-1">Recent runs</p>
|
||||
{isLoadingRuns ? (
|
||||
<p className="text-[10px] text-muted-foreground">Loading...</p>
|
||||
) : runs.length === 0 ? (
|
||||
<p className="text-[10px] text-muted-foreground">No training runs yet</p>
|
||||
) : (
|
||||
<div className="space-y-1">
|
||||
{runs.map((run) => (
|
||||
<div key={run.run_id} className="p-1.5 bg-secondary/30 rounded text-[10px] space-y-0.5">
|
||||
<div className="flex items-center justify-between gap-1">
|
||||
<span className="text-foreground font-medium capitalize">
|
||||
{run.model_type.replace('_', ' ')}
|
||||
</span>
|
||||
<StatusBadge status={run.status} />
|
||||
</div>
|
||||
<div className="text-muted-foreground">{formatDate(run.created_at)}</div>
|
||||
{run.status === 'completed' && run.metrics_summary && (
|
||||
<div className="flex flex-wrap gap-x-2 text-muted-foreground">
|
||||
{run.metrics_summary.accuracy !== undefined && (
|
||||
<span>Acc: {(run.metrics_summary.accuracy * 100).toFixed(1)}%</span>
|
||||
)}
|
||||
{run.metrics_summary.f1_macro !== undefined && (
|
||||
<span>F1: {(run.metrics_summary.f1_macro * 100).toFixed(1)}%</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{run.status === 'failed' && run.error && (
|
||||
<p className="text-destructive truncate" title={run.error}>
|
||||
{run.error}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue