feat: complete prediction UI feedback tasks (11.2, 11.4, 11.5)
- Implement disagreement visual highlighting with distinct colors - Yellow highlight for 'missed_by_human' predictions - Orange for 'label_mismatch' disagreements - Warning icon on disagreement markers - Add click-to-convert prediction feedback - Click disagreement predictions to create span annotations - Auto-fill with predicted label and times - Set source as 'model_confirmed' or 'model_corrected' - Add dismiss action for false positive predictions - Alt+Click or Ctrl+Click to dismiss predictions - Saves negative annotation with label 'O' - Records original prediction in model_prediction field - Filter predictions when 'Show only disagreements' is enabled
This commit is contained in:
parent
a18c6d110a
commit
65f00e6ce7
13 changed files with 905 additions and 11 deletions
|
|
@ -306,6 +306,83 @@ export default function Home() {
|
|||
setSelectedSpanId(spanId);
|
||||
};
|
||||
|
||||
// Handle prediction click to convert to annotation
|
||||
const handlePredictionClick = useCallback(async (span: PredictionSpan, disagreementType: string | null) => {
|
||||
if (!activeChartId) return;
|
||||
|
||||
// Find the span label type that matches the prediction label
|
||||
const matchingLabelType = spanLabelTypes.find((lt) => lt.name === span.label);
|
||||
|
||||
if (!matchingLabelType) {
|
||||
console.warn(`No span label type found for prediction label: ${span.label}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create span annotation from prediction
|
||||
try {
|
||||
const response = await fetch('/api/span-annotations', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
chart_id: activeChartId,
|
||||
start_time: span.start_time,
|
||||
end_time: span.end_time,
|
||||
label: span.label,
|
||||
confidence: 3, // Default confidence for model-confirmed annotations
|
||||
source: disagreementType === 'label_mismatch' ? 'model_corrected' : 'model_confirmed',
|
||||
model_prediction: {
|
||||
label: span.label,
|
||||
confidence: span.avg_confidence,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
await fetchSpanAnnotations(activeChartId);
|
||||
// Show a brief notification (you could add a toast notification here)
|
||||
console.log(`Created span annotation from prediction: ${span.label}`);
|
||||
} else {
|
||||
console.error('Failed to create span annotation from prediction');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error creating span annotation from prediction:', error);
|
||||
}
|
||||
}, [activeChartId, spanLabelTypes, fetchSpanAnnotations]);
|
||||
|
||||
// Handle prediction dismiss (save as negative annotation with label "O")
|
||||
const handlePredictionDismiss = useCallback(async (span: PredictionSpan, disagreementType: string | null) => {
|
||||
if (!activeChartId) return;
|
||||
|
||||
// Create negative annotation with label "O"
|
||||
try {
|
||||
const response = await fetch('/api/span-annotations', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
chart_id: activeChartId,
|
||||
start_time: span.start_time,
|
||||
end_time: span.end_time,
|
||||
label: 'O', // "O" means "not a pattern"
|
||||
confidence: 5, // High confidence for explicit user correction
|
||||
source: 'human_correction',
|
||||
model_prediction: {
|
||||
label: span.label,
|
||||
confidence: span.avg_confidence,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
await fetchSpanAnnotations(activeChartId);
|
||||
console.log(`Dismissed prediction as "not a pattern": ${span.label}`);
|
||||
} else {
|
||||
console.error('Failed to save negative annotation');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error saving negative annotation:', error);
|
||||
}
|
||||
}, [activeChartId, fetchSpanAnnotations]);
|
||||
|
||||
const handleDeleteSpan = async (spanId: number) => {
|
||||
try {
|
||||
const response = await fetch(`/api/span-annotations/${spanId}`, {
|
||||
|
|
@ -715,6 +792,8 @@ export default function Home() {
|
|||
modelInfo={predictionState.modelInfo}
|
||||
predictionSummary={predictionSummary}
|
||||
showOnlyDisagreements={showOnlyDisagreements}
|
||||
onPredictionClick={handlePredictionClick}
|
||||
onPredictionDismiss={handlePredictionDismiss}
|
||||
/>
|
||||
</main>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ interface CandleChartProps {
|
|||
modelInfo?: ModelInfoResponse | null;
|
||||
predictionSummary?: PredictionSummary | null;
|
||||
showOnlyDisagreements?: boolean;
|
||||
onPredictionClick?: (span: PredictionSpan, disagreementType: string | null) => void;
|
||||
onPredictionDismiss?: (span: PredictionSpan, disagreementType: string | null) => void;
|
||||
}
|
||||
|
||||
export interface CandleChartHandle {
|
||||
|
|
@ -112,6 +114,8 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
modelInfo = null,
|
||||
predictionSummary = null,
|
||||
showOnlyDisagreements = false,
|
||||
onPredictionClick,
|
||||
onPredictionDismiss,
|
||||
}, ref) => {
|
||||
const chartContainerRef = useRef<HTMLDivElement>(null);
|
||||
const chartRef = useRef<IChartApi | null>(null);
|
||||
|
|
@ -392,6 +396,20 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
});
|
||||
}
|
||||
|
||||
// Build disagreement lookup map for highlighting
|
||||
// Map: predictionSpan start_time -> disagreement type
|
||||
const disagreementMap = new Map<number, string>();
|
||||
const disagreementSpanSet = new Set<string>(); // Set of "start-end" keys for filtering
|
||||
|
||||
if (predictionSummary?.disagreements) {
|
||||
predictionSummary.disagreements.forEach((d) => {
|
||||
if (d.predictionSpan) {
|
||||
disagreementMap.set(d.predictionSpan.start_time, d.type);
|
||||
disagreementSpanSet.add(`${d.predictionSpan.start_time}-${d.predictionSpan.end_time}`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to convert hex to rgba
|
||||
const hexToRgba = (hex: string, alpha: number): string => {
|
||||
hex = hex.replace('#', '');
|
||||
|
|
@ -404,7 +422,16 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
// Build candle price lookup for histogram values
|
||||
const candleMap = new Map(candles.map((c) => [c.time, c]));
|
||||
|
||||
// Filter and map predictions to histogram data
|
||||
// Build a map from prediction time to its span for disagreement lookup
|
||||
const predictionTimeToSpan = new Map<number, PredictionSpan>();
|
||||
predictionSpans.forEach((span) => {
|
||||
// Associate all times in the span with the span object
|
||||
for (let t = span.start_time; t <= span.end_time; t += 60) { // Assuming 1-minute candles, adjust if needed
|
||||
predictionTimeToSpan.set(t, span);
|
||||
}
|
||||
});
|
||||
|
||||
// Filter and map predictions to histogram data with disagreement highlighting
|
||||
const histogramData: (HistogramData & { color?: string })[] = perCandlePredictions
|
||||
.filter((p) => {
|
||||
// Filter by confidence threshold
|
||||
|
|
@ -413,17 +440,44 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
if (!selectedLabels.has(p.label)) return false;
|
||||
// Skip "O" (no-pattern) labels
|
||||
if (p.label === 'O') return false;
|
||||
|
||||
// If showOnlyDisagreements is enabled, only show predictions that are part of disagreements
|
||||
if (showOnlyDisagreements) {
|
||||
const span = predictionTimeToSpan.get(p.time);
|
||||
if (!span) return false;
|
||||
const spanKey = `${span.start_time}-${span.end_time}`;
|
||||
if (!disagreementSpanSet.has(spanKey)) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
})
|
||||
.map((p) => {
|
||||
const candle = candleMap.get(p.time);
|
||||
// Use candle high as the histogram value so it overlays correctly
|
||||
const value = candle ? candle.high : 0;
|
||||
const baseColor = labelColorMap[p.label] || '#888888';
|
||||
|
||||
// Check if this prediction is part of a disagreement
|
||||
const span = predictionTimeToSpan.get(p.time);
|
||||
const disagreementType = span ? disagreementMap.get(span.start_time) : null;
|
||||
|
||||
let baseColor = labelColorMap[p.label] || '#888888';
|
||||
let alpha = 0.15;
|
||||
|
||||
// Apply disagreement-specific colors and styling
|
||||
if (disagreementType === 'missed_by_human') {
|
||||
// Yellow highlight for predictions missed by humans
|
||||
baseColor = '#eab308'; // yellow
|
||||
alpha = 0.25;
|
||||
} else if (disagreementType === 'label_mismatch') {
|
||||
// Orange for label mismatches
|
||||
baseColor = '#f97316'; // orange
|
||||
alpha = 0.25;
|
||||
}
|
||||
|
||||
return {
|
||||
time: p.time as Time,
|
||||
value,
|
||||
color: hexToRgba(baseColor, 0.15),
|
||||
color: hexToRgba(baseColor, alpha),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => (a.time as number) - (b.time as number));
|
||||
|
|
@ -436,24 +490,43 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
if (span.avg_confidence < confidenceThreshold) return false;
|
||||
if (!selectedLabels.has(span.label)) return false;
|
||||
if (span.label === 'O') return false;
|
||||
|
||||
// If showOnlyDisagreements is enabled, only show spans that are disagreements
|
||||
if (showOnlyDisagreements) {
|
||||
const spanKey = `${span.start_time}-${span.end_time}`;
|
||||
if (!disagreementSpanSet.has(spanKey)) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
})
|
||||
.map((span) => {
|
||||
const baseColor = labelColorMap[span.label] || '#888888';
|
||||
const disagreementType = disagreementMap.get(span.start_time);
|
||||
let baseColor = labelColorMap[span.label] || '#888888';
|
||||
let labelText = span.label;
|
||||
|
||||
// Apply disagreement-specific styling to markers
|
||||
if (disagreementType === 'missed_by_human') {
|
||||
baseColor = '#eab308'; // yellow
|
||||
labelText = `⚠ ${span.label}`;
|
||||
} else if (disagreementType === 'label_mismatch') {
|
||||
baseColor = '#f97316'; // orange
|
||||
labelText = `⚠ ${span.label}`;
|
||||
}
|
||||
|
||||
const confidencePct = Math.round(span.avg_confidence * 100);
|
||||
return {
|
||||
time: span.start_time as Time,
|
||||
position: 'belowBar' as const,
|
||||
color: baseColor,
|
||||
shape: 'square' as const,
|
||||
text: `${span.label} (${confidencePct}%)`,
|
||||
text: `${labelText} (${confidencePct}%)`,
|
||||
size: 1,
|
||||
};
|
||||
})
|
||||
.sort((a, b) => (a.time as number) - (b.time as number));
|
||||
|
||||
histogramSeriesRef.current.setMarkers(spanMarkers);
|
||||
}, [predictionVisible, perCandlePredictions, predictionSpans, confidenceThreshold, selectedLabels, modelInfo, candles]);
|
||||
}, [predictionVisible, perCandlePredictions, predictionSpans, confidenceThreshold, selectedLabels, modelInfo, candles, predictionSummary, showOnlyDisagreements]);
|
||||
|
||||
// Handle chart clicks for annotation
|
||||
useEffect(() => {
|
||||
|
|
@ -537,6 +610,41 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
}
|
||||
}
|
||||
|
||||
// Handle clicks on prediction spans (for converting to annotations or dismissing)
|
||||
if (!activeTool && predictionVisible && predictionSpans.length > 0) {
|
||||
const timestamp = typeof time === 'string' ? Date.parse(time) / 1000 : (time as number);
|
||||
|
||||
// Find if click is within any prediction span
|
||||
const clickedSpan = predictionSpans.find(
|
||||
(span) => timestamp >= span.start_time && timestamp <= span.end_time
|
||||
);
|
||||
|
||||
if (clickedSpan) {
|
||||
// Check if this span is a disagreement
|
||||
const disagreementMap = new Map<number, string>();
|
||||
if (predictionSummary?.disagreements) {
|
||||
predictionSummary.disagreements.forEach((d) => {
|
||||
if (d.predictionSpan) {
|
||||
disagreementMap.set(d.predictionSpan.start_time, d.type);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const disagreementType = disagreementMap.get(clickedSpan.start_time) || null;
|
||||
|
||||
// Check if Alt or Ctrl key is pressed for dismiss action
|
||||
if ((param.sourceEvent?.altKey || param.sourceEvent?.ctrlKey) && onPredictionDismiss) {
|
||||
// Alt+Click or Ctrl+Click: Dismiss as "not a pattern"
|
||||
onPredictionDismiss(clickedSpan, disagreementType);
|
||||
} else if (onPredictionClick) {
|
||||
// Normal click: Convert to annotation (only for disagreements)
|
||||
if (disagreementType === 'missed_by_human' || disagreementType === 'label_mismatch') {
|
||||
onPredictionClick(clickedSpan, disagreementType);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Select/deselect label markers by clicking them
|
||||
const isMarkerTool = annotationTypes.find(
|
||||
(t) => t.category === 'marker' && t.name === activeTool
|
||||
|
|
@ -567,7 +675,7 @@ const CandleChart = forwardRef<CandleChartHandle, CandleChartProps>(
|
|||
return () => {
|
||||
chartRef.current?.unsubscribeClick(handleClick);
|
||||
};
|
||||
}, [activeTool, candles, annotations, annotationTypes, onAnnotationChange]);
|
||||
}, [activeTool, candles, annotations, annotationTypes, onAnnotationChange, predictionVisible, predictionSpans, predictionSummary, onPredictionClick, onPredictionDismiss]);
|
||||
|
||||
// Fetch data on mount
|
||||
useEffect(() => {
|
||||
|
|
|
|||
204
src/plugins/trend-line.ts
Normal file
204
src/plugins/trend-line.ts
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
import { BitmapCoordinatesRenderingScope, CanvasRenderingTarget2D } from 'fancy-canvas';
|
||||
import {
|
||||
AutoscaleInfo,
|
||||
Coordinate,
|
||||
IChartApi,
|
||||
ISeriesApi,
|
||||
ISeriesPrimitive,
|
||||
IPrimitivePaneRenderer,
|
||||
IPrimitivePaneView,
|
||||
Logical,
|
||||
SeriesOptionsMap,
|
||||
SeriesType,
|
||||
Time,
|
||||
} from 'lightweight-charts';
|
||||
|
||||
class TrendLinePaneRenderer implements IPrimitivePaneRenderer {
|
||||
_p1: ViewPoint;
|
||||
_p2: ViewPoint;
|
||||
_text1: string;
|
||||
_text2: string;
|
||||
_options: TrendLineOptions;
|
||||
|
||||
constructor(p1: ViewPoint, p2: ViewPoint, text1: string, text2: string, options: TrendLineOptions) {
|
||||
this._p1 = p1;
|
||||
this._p2 = p2;
|
||||
this._text1 = text1;
|
||||
this._text2 = text2;
|
||||
this._options = options;
|
||||
}
|
||||
|
||||
draw(target: CanvasRenderingTarget2D) {
|
||||
target.useBitmapCoordinateSpace(scope => {
|
||||
if (
|
||||
this._p1.x === null ||
|
||||
this._p1.y === null ||
|
||||
this._p2.x === null ||
|
||||
this._p2.y === null
|
||||
)
|
||||
return;
|
||||
const ctx = scope.context;
|
||||
const x1Scaled = Math.round(this._p1.x * scope.horizontalPixelRatio);
|
||||
const y1Scaled = Math.round(this._p1.y * scope.verticalPixelRatio);
|
||||
const x2Scaled = Math.round(this._p2.x * scope.horizontalPixelRatio);
|
||||
const y2Scaled = Math.round(this._p2.y * scope.verticalPixelRatio);
|
||||
ctx.lineWidth = this._options.width;
|
||||
ctx.strokeStyle = this._options.lineColor;
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(x1Scaled, y1Scaled);
|
||||
ctx.lineTo(x2Scaled, y2Scaled);
|
||||
ctx.stroke();
|
||||
if (this._options.showLabels) {
|
||||
this._drawTextLabel(scope, this._text1, x1Scaled, y1Scaled, true);
|
||||
this._drawTextLabel(scope, this._text2, x2Scaled, y2Scaled, false);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
_drawTextLabel(scope: BitmapCoordinatesRenderingScope, text: string, x: number, y: number, left: boolean) {
|
||||
scope.context.font = '24px Arial';
|
||||
scope.context.beginPath();
|
||||
const offset = 5 * scope.horizontalPixelRatio;
|
||||
const textWidth = scope.context.measureText(text);
|
||||
const leftAdjustment = left ? textWidth.width + offset * 4 : 0;
|
||||
scope.context.fillStyle = this._options.labelBackgroundColor;
|
||||
scope.context.roundRect(x + offset - leftAdjustment, y - 24, textWidth.width + offset * 2, 24 + offset, 5);
|
||||
scope.context.fill();
|
||||
scope.context.beginPath();
|
||||
scope.context.fillStyle = this._options.labelTextColor;
|
||||
scope.context.fillText(text, x + offset * 2 - leftAdjustment, y);
|
||||
}
|
||||
}
|
||||
|
||||
interface ViewPoint {
|
||||
x: Coordinate | null;
|
||||
y: Coordinate | null;
|
||||
}
|
||||
|
||||
class TrendLinePaneView implements IPrimitivePaneView {
|
||||
_source: TrendLine;
|
||||
_p1: ViewPoint = { x: null, y: null };
|
||||
_p2: ViewPoint = { x: null, y: null };
|
||||
|
||||
constructor(source: TrendLine) {
|
||||
this._source = source;
|
||||
}
|
||||
|
||||
update() {
|
||||
const series = this._source._series;
|
||||
const y1 = series.priceToCoordinate(this._source._p1.price);
|
||||
const y2 = series.priceToCoordinate(this._source._p2.price);
|
||||
const timeScale = this._source._chart.timeScale();
|
||||
const x1 = timeScale.timeToCoordinate(this._source._p1.time);
|
||||
const x2 = timeScale.timeToCoordinate(this._source._p2.time);
|
||||
this._p1 = { x: x1, y: y1 };
|
||||
this._p2 = { x: x2, y: y2 };
|
||||
}
|
||||
|
||||
renderer() {
|
||||
return new TrendLinePaneRenderer(
|
||||
this._p1,
|
||||
this._p2,
|
||||
'' + this._source._p1.price.toFixed(1),
|
||||
'' + this._source._p2.price.toFixed(1),
|
||||
this._source._options
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
interface Point {
|
||||
time: Time;
|
||||
price: number;
|
||||
}
|
||||
|
||||
export interface TrendLineOptions {
|
||||
lineColor: string;
|
||||
width: number;
|
||||
showLabels: boolean;
|
||||
labelBackgroundColor: string;
|
||||
labelTextColor: string;
|
||||
}
|
||||
|
||||
const defaultOptions: TrendLineOptions = {
|
||||
lineColor: 'rgb(0, 0, 0)',
|
||||
width: 2,
|
||||
showLabels: false,
|
||||
labelBackgroundColor: 'rgba(255, 255, 255, 0.85)',
|
||||
labelTextColor: 'rgb(0, 0, 0)',
|
||||
};
|
||||
|
||||
export class TrendLine implements ISeriesPrimitive<Time> {
|
||||
_chart: IChartApi;
|
||||
_series: ISeriesApi<keyof SeriesOptionsMap>;
|
||||
_p1: Point;
|
||||
_p2: Point;
|
||||
_paneViews: TrendLinePaneView[];
|
||||
_options: TrendLineOptions;
|
||||
_minPrice: number;
|
||||
_maxPrice: number;
|
||||
|
||||
constructor(
|
||||
chart: IChartApi,
|
||||
series: ISeriesApi<SeriesType>,
|
||||
p1: Point,
|
||||
p2: Point,
|
||||
options?: Partial<TrendLineOptions>
|
||||
) {
|
||||
this._chart = chart;
|
||||
this._series = series;
|
||||
this._p1 = p1;
|
||||
this._p2 = p2;
|
||||
this._minPrice = Math.min(this._p1.price, this._p2.price);
|
||||
this._maxPrice = Math.max(this._p1.price, this._p2.price);
|
||||
this._options = {
|
||||
...defaultOptions,
|
||||
...options,
|
||||
};
|
||||
this._paneViews = [new TrendLinePaneView(this)];
|
||||
}
|
||||
|
||||
updatePoints(p1: Point, p2: Point) {
|
||||
this._p1 = p1;
|
||||
this._p2 = p2;
|
||||
this._minPrice = Math.min(this._p1.price, this._p2.price);
|
||||
this._maxPrice = Math.max(this._p1.price, this._p2.price);
|
||||
}
|
||||
|
||||
getP1(): Point {
|
||||
return this._p1;
|
||||
}
|
||||
|
||||
getP2(): Point {
|
||||
return this._p2;
|
||||
}
|
||||
|
||||
autoscaleInfo(startTimePoint: Logical, endTimePoint: Logical): AutoscaleInfo | null {
|
||||
const p1Index = this._pointIndex(this._p1);
|
||||
const p2Index = this._pointIndex(this._p2);
|
||||
if (p1Index === null || p2Index === null) return null;
|
||||
if (endTimePoint < p1Index || startTimePoint > p2Index) return null;
|
||||
return {
|
||||
priceRange: {
|
||||
minValue: this._minPrice,
|
||||
maxValue: this._maxPrice,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
updateAllViews() {
|
||||
this._paneViews.forEach(pw => pw.update());
|
||||
}
|
||||
|
||||
paneViews() {
|
||||
return this._paneViews;
|
||||
}
|
||||
|
||||
_pointIndex(p: Point): number | null {
|
||||
const coordinate = this._chart
|
||||
.timeScale()
|
||||
.timeToCoordinate(p.time);
|
||||
if (coordinate === null) return null;
|
||||
const index = this._chart.timeScale().coordinateToLogical(coordinate);
|
||||
return index;
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue