feat(ui): add prediction state management and PredictionPanel component

- Create prediction type definitions in src/types/predictions.ts
- Add prediction state management to page.tsx with caching
- Implement PredictionPanel component with:
  - Master visibility toggle
  - Model info display (name, version, type, metrics)
  - Action buttons (Run on Visible, Predict All)
  - Confidence threshold slider
  - Label filter checkboxes with per-class metrics
  - Disagreement filter toggle
  - Prediction summary display
  - Model server offline banner
- Add on-demand and batch prediction fetching
- Implement prediction caching by chart and model version
- Add health polling for inference API (30s interval when offline)
- Ensure annotation tools work independently of prediction API

Tasks completed: 9.1-9.5, 12.1-12.3 (59/78 total)
This commit is contained in:
Marko Djordjevic 2026-02-15 16:20:07 +01:00
parent bb1b6d573f
commit 28ebe2c5d1
4 changed files with 608 additions and 8 deletions

View file

@ -5,6 +5,8 @@ import Toolbox, { Tool } from '@/components/Toolbox';
import FileUpload from '@/components/FileUpload';
import CandleChart, { CandleChartHandle } from '@/components/CandleChart';
import ChartSelector from '@/components/ChartSelector';
import PredictionPanel from '@/components/PredictionPanel';
import type { PredictionState, PredictionSpan, ModelInfoResponse } from '@/types/predictions';
interface Chart {
id: number;
@ -60,6 +62,30 @@ export default function Home() {
const [selectedSpanId, setSelectedSpanId] = useState<number | null>(null);
const [spanLabelTypes, setSpanLabelTypes] = useState<SpanLabelType[]>([]);
// Prediction state
const [predictionState, setPredictionState] = useState<PredictionState>({
spans: [],
perCandlePredictions: [],
isLoading: false,
error: null,
modelInfo: null,
visible: false,
confidenceThreshold: 0.5,
selectedLabels: new Set<string>(),
autoPredict: false,
cacheKey: null,
});
// Prediction cache: Map<cacheKey, { spans, predictions, modelVersion }>
const predictionCacheRef = useRef<Map<string, {
spans: PredictionSpan[];
predictions: any[];
modelVersion: string;
}>>(new Map());
// Model health state
const [isModelOnline, setIsModelOnline] = useState(true);
// Fetch charts list
const fetchCharts = useCallback(async () => {
try {
@ -215,6 +241,238 @@ export default function Home() {
}
};
// Fetch model info and initialize selected labels
const fetchModelInfo = useCallback(async () => {
try {
const response = await fetch('/api/model/info');
if (!response.ok) {
setIsModelOnline(false);
throw new Error('Model info unavailable');
}
const data: ModelInfoResponse = await response.json();
setIsModelOnline(true);
setPredictionState((prev) => ({
...prev,
modelInfo: data,
selectedLabels: new Set(data.label_config.map((l) => l.name)),
error: null,
}));
return data;
} catch (error) {
console.error('Failed to fetch model info:', error);
setIsModelOnline(false);
setPredictionState((prev) => ({
...prev,
modelInfo: null,
error: error instanceof Error ? error.message : 'Failed to fetch model info',
}));
return null;
}
}, []);
// Generate cache key from chart, timerange, and model version
const generateCacheKey = useCallback((chartId: number | null, modelVersion?: string) => {
if (!chartId) return null;
const version = modelVersion || predictionState.modelInfo?.model_info.model_version || 'unknown';
return `${chartId}_${version}`;
}, [predictionState.modelInfo]);
// Fetch predictions for visible candles
const fetchPredictions = useCallback(async (candles: any[]) => {
if (!activeChartId || candles.length === 0) return;
const cacheKey = generateCacheKey(activeChartId, predictionState.modelInfo?.model_info.model_version);
// Check cache first
if (cacheKey && predictionCacheRef.current.has(cacheKey)) {
const cached = predictionCacheRef.current.get(cacheKey)!;
if (cached.modelVersion === predictionState.modelInfo?.model_info.model_version) {
setPredictionState((prev) => ({
...prev,
spans: cached.spans,
perCandlePredictions: cached.predictions,
cacheKey,
}));
return;
}
}
setPredictionState((prev) => ({ ...prev, isLoading: true, error: null }));
try {
const response = await fetch('/api/predict', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ candles }),
});
if (!response.ok) {
throw new Error(`Prediction failed: ${response.statusText}`);
}
const data = await response.json();
// Cache the results
if (cacheKey) {
predictionCacheRef.current.set(cacheKey, {
spans: data.spans,
predictions: data.predictions,
modelVersion: data.model_info.model_version,
});
}
setPredictionState((prev) => ({
...prev,
spans: data.spans,
perCandlePredictions: data.predictions,
isLoading: false,
cacheKey,
}));
} catch (error) {
console.error('Failed to fetch predictions:', error);
setPredictionState((prev) => ({
...prev,
isLoading: false,
error: error instanceof Error ? error.message : 'Failed to fetch predictions',
}));
}
}, [activeChartId, predictionState.modelInfo, generateCacheKey]);
// Toggle prediction visibility
const togglePredictionVisibility = useCallback(() => {
setPredictionState((prev) => ({ ...prev, visible: !prev.visible }));
}, []);
// Update confidence threshold
const setConfidenceThreshold = useCallback((threshold: number) => {
setPredictionState((prev) => ({ ...prev, confidenceThreshold: threshold }));
}, []);
// Toggle label selection
const toggleLabelSelection = useCallback((label: string) => {
setPredictionState((prev) => {
const newSelected = new Set(prev.selectedLabels);
if (newSelected.has(label)) {
newSelected.delete(label);
} else {
newSelected.add(label);
}
return { ...prev, selectedLabels: newSelected };
});
}, []);
// Handle on-demand prediction for visible candles
const handleFetchVisiblePredictions = useCallback(() => {
// This will be called by the PredictionPanel
// The actual candles data will be fetched from the chart ref
const candles = chartRef.current?.getVisibleCandles();
if (candles && candles.length > 0) {
fetchPredictions(candles);
}
}, [fetchPredictions]);
// Handle batch prediction for all candles
const handleFetchBatchPredictions = useCallback(async () => {
if (!activeChartId) return;
setPredictionState((prev) => ({ ...prev, isLoading: true, error: null }));
try {
// Fetch chart data to get pair/timeframe info
const chartResponse = await fetch(`/api/charts/${activeChartId}`);
if (!chartResponse.ok) {
throw new Error('Failed to fetch chart info');
}
const chartData = await chartResponse.json();
// Fetch candles for the chart
const candlesResponse = await fetch(`/api/candles?chartId=${activeChartId}`);
if (!candlesResponse.ok) {
throw new Error('Failed to fetch candles');
}
const candlesData = await candlesResponse.json();
if (candlesData.length === 0) {
throw new Error('No candles found for this chart');
}
const startTime = candlesData[0].time;
const endTime = candlesData[candlesData.length - 1].time;
// Make batch prediction request
const response = await fetch('/api/predict/batch', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
pair: chartData.name,
timeframe: '1h', // TODO: Get from chart metadata
start_time: startTime,
end_time: endTime,
}),
});
if (!response.ok) {
throw new Error(`Batch prediction failed: ${response.statusText}`);
}
const data = await response.json();
const cacheKey = generateCacheKey(activeChartId, data.model_info.model_version);
if (cacheKey) {
predictionCacheRef.current.set(cacheKey, {
spans: data.spans,
predictions: data.predictions,
modelVersion: data.model_info.model_version,
});
}
setPredictionState((prev) => ({
...prev,
spans: data.spans,
perCandlePredictions: data.predictions,
isLoading: false,
cacheKey,
}));
} catch (error) {
console.error('Failed to fetch batch predictions:', error);
setPredictionState((prev) => ({
...prev,
isLoading: false,
error: error instanceof Error ? error.message : 'Failed to fetch batch predictions',
}));
}
}, [activeChartId, generateCacheKey]);
// Clear prediction cache when model version changes
useEffect(() => {
if (predictionState.modelInfo) {
const currentVersion = predictionState.modelInfo.model_info.model_version;
// Clear cache entries with different model versions
const newCache = new Map();
for (const [key, value] of predictionCacheRef.current.entries()) {
if (value.modelVersion === currentVersion) {
newCache.set(key, value);
}
}
predictionCacheRef.current = newCache;
}
}, [predictionState.modelInfo?.model_info.model_version]);
// Health polling - check model status every 30 seconds when offline
useEffect(() => {
if (!isModelOnline) {
const interval = setInterval(() => {
fetchModelInfo();
}, 30000);
return () => clearInterval(interval);
}
}, [isModelOnline, fetchModelInfo]);
// Initialize model info on mount
useEffect(() => {
fetchModelInfo();
}, [fetchModelInfo]);
// Keyboard handler for Delete/Backspace key
useEffect(() => {
const handleKeyDown = async (e: KeyboardEvent) => {
@ -289,6 +547,15 @@ export default function Home() {
onDeleteSpan={handleDeleteSpan}
/>
</div>
<PredictionPanel
predictionState={predictionState}
onToggleVisibility={togglePredictionVisibility}
onFetchPredictions={handleFetchVisiblePredictions}
onFetchBatchPredictions={handleFetchBatchPredictions}
onConfidenceChange={setConfidenceThreshold}
onToggleLabelSelection={toggleLabelSelection}
isModelOnline={isModelOnline}
/>
</aside>
{/* Main chart area */}