diff --git a/openspec/changes/candle-backend/tasks.md b/openspec/changes/candle-backend/tasks.md index 695af23..c71c821 100644 --- a/openspec/changes/candle-backend/tasks.md +++ b/openspec/changes/candle-backend/tasks.md @@ -78,11 +78,11 @@ ## 9. Prediction UI — State & Controls -- [ ] 9.1 Create `src/types/predictions.ts` — PredictionSpan, PredictionState, ModelInfoResponse interfaces -- [ ] 9.2 Create prediction state management in page.tsx (or dedicated context) — spans, isLoading, error, modelInfo, visible, confidenceThreshold, selectedLabels, autoPredict -- [ ] 9.3 Create `src/components/PredictionPanel.tsx` — controls panel with master toggle, model info display, action buttons, confidence slider, label checkboxes with metrics -- [ ] 9.4 Implement on-demand prediction fetching — "Run on Visible" sends visible candles to /api/predict, "Predict All" sends batch request -- [ ] 9.5 Implement prediction caching — Map keyed by pair_timeframe_range_modelVersion, invalidate on model version change +- [x] 9.1 Create `src/types/predictions.ts` — PredictionSpan, PredictionState, ModelInfoResponse interfaces +- [x] 9.2 Create prediction state management in page.tsx (or dedicated context) — spans, isLoading, error, modelInfo, visible, confidenceThreshold, selectedLabels, autoPredict +- [x] 9.3 Create `src/components/PredictionPanel.tsx` — controls panel with master toggle, model info display, action buttons, confidence slider, label checkboxes with metrics +- [x] 9.4 Implement on-demand prediction fetching — "Run on Visible" sends visible candles to /api/predict, "Predict All" sends batch request +- [x] 9.5 Implement prediction caching — Map keyed by pair_timeframe_range_modelVersion, invalidate on model version change ## 10. Prediction UI — Chart Rendering @@ -103,9 +103,9 @@ ## 12. Inference API Connection & Error Handling -- [ ] 12.1 Implement inference API health polling — poll /api/model/info every 30 seconds when API unavailable, auto-reconnect -- [ ] 12.2 Show "Model server offline" banner when inference API unavailable, disable prediction controls -- [ ] 12.3 Ensure annotation tools work independently — prediction API errors never block human annotation +- [x] 12.1 Implement inference API health polling — poll /api/model/info every 30 seconds when API unavailable, auto-reconnect +- [x] 12.2 Show "Model server offline" banner when inference API unavailable, disable prediction controls +- [x] 12.3 Ensure annotation tools work independently — prediction API errors never block human annotation - [ ] 12.4 Add loading states for prediction fetching — skeleton/shimmer overlay during prediction requests ## 13. Documentation & Deployment diff --git a/src/app/page.tsx b/src/app/page.tsx index 9206280..0d40564 100644 --- a/src/app/page.tsx +++ b/src/app/page.tsx @@ -5,6 +5,8 @@ import Toolbox, { Tool } from '@/components/Toolbox'; import FileUpload from '@/components/FileUpload'; import CandleChart, { CandleChartHandle } from '@/components/CandleChart'; import ChartSelector from '@/components/ChartSelector'; +import PredictionPanel from '@/components/PredictionPanel'; +import type { PredictionState, PredictionSpan, ModelInfoResponse } from '@/types/predictions'; interface Chart { id: number; @@ -60,6 +62,30 @@ export default function Home() { const [selectedSpanId, setSelectedSpanId] = useState(null); const [spanLabelTypes, setSpanLabelTypes] = useState([]); + // Prediction state + const [predictionState, setPredictionState] = useState({ + spans: [], + perCandlePredictions: [], + isLoading: false, + error: null, + modelInfo: null, + visible: false, + confidenceThreshold: 0.5, + selectedLabels: new Set(), + autoPredict: false, + cacheKey: null, + }); + + // Prediction cache: Map + const predictionCacheRef = useRef>(new Map()); + + // Model health state + const [isModelOnline, setIsModelOnline] = useState(true); + // Fetch charts list const fetchCharts = useCallback(async () => { try { @@ -215,6 +241,238 @@ export default function Home() { } }; + // Fetch model info and initialize selected labels + const fetchModelInfo = useCallback(async () => { + try { + const response = await fetch('/api/model/info'); + if (!response.ok) { + setIsModelOnline(false); + throw new Error('Model info unavailable'); + } + const data: ModelInfoResponse = await response.json(); + setIsModelOnline(true); + setPredictionState((prev) => ({ + ...prev, + modelInfo: data, + selectedLabels: new Set(data.label_config.map((l) => l.name)), + error: null, + })); + return data; + } catch (error) { + console.error('Failed to fetch model info:', error); + setIsModelOnline(false); + setPredictionState((prev) => ({ + ...prev, + modelInfo: null, + error: error instanceof Error ? error.message : 'Failed to fetch model info', + })); + return null; + } + }, []); + + // Generate cache key from chart, timerange, and model version + const generateCacheKey = useCallback((chartId: number | null, modelVersion?: string) => { + if (!chartId) return null; + const version = modelVersion || predictionState.modelInfo?.model_info.model_version || 'unknown'; + return `${chartId}_${version}`; + }, [predictionState.modelInfo]); + + // Fetch predictions for visible candles + const fetchPredictions = useCallback(async (candles: any[]) => { + if (!activeChartId || candles.length === 0) return; + + const cacheKey = generateCacheKey(activeChartId, predictionState.modelInfo?.model_info.model_version); + + // Check cache first + if (cacheKey && predictionCacheRef.current.has(cacheKey)) { + const cached = predictionCacheRef.current.get(cacheKey)!; + if (cached.modelVersion === predictionState.modelInfo?.model_info.model_version) { + setPredictionState((prev) => ({ + ...prev, + spans: cached.spans, + perCandlePredictions: cached.predictions, + cacheKey, + })); + return; + } + } + + setPredictionState((prev) => ({ ...prev, isLoading: true, error: null })); + + try { + const response = await fetch('/api/predict', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ candles }), + }); + + if (!response.ok) { + throw new Error(`Prediction failed: ${response.statusText}`); + } + + const data = await response.json(); + + // Cache the results + if (cacheKey) { + predictionCacheRef.current.set(cacheKey, { + spans: data.spans, + predictions: data.predictions, + modelVersion: data.model_info.model_version, + }); + } + + setPredictionState((prev) => ({ + ...prev, + spans: data.spans, + perCandlePredictions: data.predictions, + isLoading: false, + cacheKey, + })); + } catch (error) { + console.error('Failed to fetch predictions:', error); + setPredictionState((prev) => ({ + ...prev, + isLoading: false, + error: error instanceof Error ? error.message : 'Failed to fetch predictions', + })); + } + }, [activeChartId, predictionState.modelInfo, generateCacheKey]); + + // Toggle prediction visibility + const togglePredictionVisibility = useCallback(() => { + setPredictionState((prev) => ({ ...prev, visible: !prev.visible })); + }, []); + + // Update confidence threshold + const setConfidenceThreshold = useCallback((threshold: number) => { + setPredictionState((prev) => ({ ...prev, confidenceThreshold: threshold })); + }, []); + + // Toggle label selection + const toggleLabelSelection = useCallback((label: string) => { + setPredictionState((prev) => { + const newSelected = new Set(prev.selectedLabels); + if (newSelected.has(label)) { + newSelected.delete(label); + } else { + newSelected.add(label); + } + return { ...prev, selectedLabels: newSelected }; + }); + }, []); + + // Handle on-demand prediction for visible candles + const handleFetchVisiblePredictions = useCallback(() => { + // This will be called by the PredictionPanel + // The actual candles data will be fetched from the chart ref + const candles = chartRef.current?.getVisibleCandles(); + if (candles && candles.length > 0) { + fetchPredictions(candles); + } + }, [fetchPredictions]); + + // Handle batch prediction for all candles + const handleFetchBatchPredictions = useCallback(async () => { + if (!activeChartId) return; + + setPredictionState((prev) => ({ ...prev, isLoading: true, error: null })); + + try { + // Fetch chart data to get pair/timeframe info + const chartResponse = await fetch(`/api/charts/${activeChartId}`); + if (!chartResponse.ok) { + throw new Error('Failed to fetch chart info'); + } + const chartData = await chartResponse.json(); + + // Fetch candles for the chart + const candlesResponse = await fetch(`/api/candles?chartId=${activeChartId}`); + if (!candlesResponse.ok) { + throw new Error('Failed to fetch candles'); + } + const candlesData = await candlesResponse.json(); + + if (candlesData.length === 0) { + throw new Error('No candles found for this chart'); + } + + const startTime = candlesData[0].time; + const endTime = candlesData[candlesData.length - 1].time; + + // Make batch prediction request + const response = await fetch('/api/predict/batch', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + pair: chartData.name, + timeframe: '1h', // TODO: Get from chart metadata + start_time: startTime, + end_time: endTime, + }), + }); + + if (!response.ok) { + throw new Error(`Batch prediction failed: ${response.statusText}`); + } + + const data = await response.json(); + + const cacheKey = generateCacheKey(activeChartId, data.model_info.model_version); + if (cacheKey) { + predictionCacheRef.current.set(cacheKey, { + spans: data.spans, + predictions: data.predictions, + modelVersion: data.model_info.model_version, + }); + } + + setPredictionState((prev) => ({ + ...prev, + spans: data.spans, + perCandlePredictions: data.predictions, + isLoading: false, + cacheKey, + })); + } catch (error) { + console.error('Failed to fetch batch predictions:', error); + setPredictionState((prev) => ({ + ...prev, + isLoading: false, + error: error instanceof Error ? error.message : 'Failed to fetch batch predictions', + })); + } + }, [activeChartId, generateCacheKey]); + + // Clear prediction cache when model version changes + useEffect(() => { + if (predictionState.modelInfo) { + const currentVersion = predictionState.modelInfo.model_info.model_version; + // Clear cache entries with different model versions + const newCache = new Map(); + for (const [key, value] of predictionCacheRef.current.entries()) { + if (value.modelVersion === currentVersion) { + newCache.set(key, value); + } + } + predictionCacheRef.current = newCache; + } + }, [predictionState.modelInfo?.model_info.model_version]); + + // Health polling - check model status every 30 seconds when offline + useEffect(() => { + if (!isModelOnline) { + const interval = setInterval(() => { + fetchModelInfo(); + }, 30000); + return () => clearInterval(interval); + } + }, [isModelOnline, fetchModelInfo]); + + // Initialize model info on mount + useEffect(() => { + fetchModelInfo(); + }, [fetchModelInfo]); + // Keyboard handler for Delete/Backspace key useEffect(() => { const handleKeyDown = async (e: KeyboardEvent) => { @@ -289,6 +547,15 @@ export default function Home() { onDeleteSpan={handleDeleteSpan} /> + {/* Main chart area */} diff --git a/src/components/PredictionPanel.tsx b/src/components/PredictionPanel.tsx new file mode 100644 index 0000000..2a3c7e0 --- /dev/null +++ b/src/components/PredictionPanel.tsx @@ -0,0 +1,215 @@ +'use client'; + +import { useState } from 'react'; +import type { PredictionState, ModelInfoResponse, PredictionSummary } from '@/types/predictions'; + +interface PredictionPanelProps { + predictionState: PredictionState; + onToggleVisibility: () => void; + onFetchPredictions: () => void; + onFetchBatchPredictions: () => void; + onConfidenceChange: (threshold: number) => void; + onToggleLabelSelection: (label: string) => void; + predictionSummary?: PredictionSummary; + isModelOnline: boolean; +} + +export default function PredictionPanel({ + predictionState, + onToggleVisibility, + onFetchPredictions, + onFetchBatchPredictions, + onConfidenceChange, + onToggleLabelSelection, + predictionSummary, + isModelOnline, +}: PredictionPanelProps) { + const [showOnlyDisagreements, setShowOnlyDisagreements] = useState(false); + + const { + visible, + isLoading, + error, + modelInfo, + confidenceThreshold, + selectedLabels, + spans, + } = predictionState; + + if (!isModelOnline) { + return ( +
+
+
+

Model Server Offline

+
+

+ Prediction service is unavailable. Annotation tools continue to work normally. +

+
+ ); + } + + return ( +
+ {/* Header with master toggle */} +
+
+
+

Predictions

+
+ +
+ + {/* Model Info */} + {modelInfo && ( +
+
+ Model: + {modelInfo.model_info.model_name} +
+
+ Version: + {modelInfo.model_info.model_version} +
+
+ Type: + {modelInfo.model_info.model_type} +
+
+ Accuracy: + {(modelInfo.metrics.accuracy * 100).toFixed(1)}% +
+
+ F1 (macro): + {(modelInfo.metrics.f1_macro * 100).toFixed(1)}% +
+
+ )} + + {/* Action Buttons */} +
+ + +
+ + {/* Error Display */} + {error && ( +
+ {error} +
+ )} + + {/* Confidence Slider */} +
+
+ + {(confidenceThreshold * 100).toFixed(0)}% +
+ onConfidenceChange(Number(e.target.value) / 100)} + className="w-full h-1 bg-muted rounded-lg appearance-none cursor-pointer" + /> +
+ + {/* Label Filter Checkboxes */} + {modelInfo && ( +
+ +
+ {modelInfo.label_config.map((labelConfig) => { + const metrics = modelInfo.metrics.per_class[labelConfig.name]; + const isSelected = selectedLabels.has(labelConfig.name); + + return ( +
+ )} + + {/* Disagreement Filter */} + {predictionSummary && predictionSummary.disagreements.length > 0 && ( +
+ +
+ )} + + {/* Prediction Summary */} + {visible && spans.length > 0 && predictionSummary && ( +
+
+ Predictions: + {predictionSummary.total_predictions} +
+
+ Human annotations: + {predictionSummary.total_human_annotations} +
+
+ Agreements: + {predictionSummary.agreements} +
+
+ Disagreements: + {predictionSummary.disagreements.length} +
+
+ )} +
+ ); +} diff --git a/src/types/predictions.ts b/src/types/predictions.ts new file mode 100644 index 0000000..f8a79dd --- /dev/null +++ b/src/types/predictions.ts @@ -0,0 +1,118 @@ +/** + * Prediction types for ML model inference + */ + +export interface PredictionSpan { + label: string; + start_time: number; + end_time: number; + avg_confidence: number; + candle_count: number; +} + +export interface PerCandlePrediction { + time: number; + label: string; + confidence: number; +} + +export interface ModelInfo { + model_name: string; + model_version: string; + model_type: string; + experiment_name: string; + run_id: string; + trained_at: string; + feature_count: number; + label_names: string[]; +} + +export interface PerClassMetrics { + [label: string]: { + precision: number; + recall: number; + f1_score: number; + support: number; + }; +} + +export interface ModelMetrics { + accuracy: number; + f1_macro: number; + f1_weighted: number; + per_class: PerClassMetrics; +} + +export interface ModelInfoResponse { + model_info: ModelInfo; + metrics: ModelMetrics; + label_config: { + name: string; + color: string; + }[]; +} + +export interface PredictRequest { + candles: { + time: number; + open: number; + high: number; + low: number; + close: number; + volume?: number; + }[]; +} + +export interface PredictResponse { + predictions: PerCandlePrediction[]; + spans: PredictionSpan[]; + model_info: { + model_name: string; + model_version: string; + }; +} + +export interface BatchPredictRequest { + pair: string; + timeframe: string; + start_time: number; + end_time: number; + batch_size?: number; +} + +export interface PredictionState { + spans: PredictionSpan[]; + perCandlePredictions: PerCandlePrediction[]; + isLoading: boolean; + error: string | null; + modelInfo: ModelInfoResponse | null; + visible: boolean; + confidenceThreshold: number; + selectedLabels: Set; + autoPredict: boolean; + cacheKey: string | null; +} + +export type DisagreementType = + | 'missed_by_model' // Human annotation but no prediction + | 'missed_by_human' // Prediction but no human annotation + | 'label_mismatch'; // Both present but different labels + +export interface Disagreement { + type: DisagreementType; + humanSpan?: { + id: number; + label: string; + start_time: number; + end_time: number; + }; + predictionSpan?: PredictionSpan; + overlap_ratio?: number; +} + +export interface PredictionSummary { + total_predictions: number; + total_human_annotations: number; + agreements: number; + disagreements: Disagreement[]; +}