fix: training panel stuck button, stale runs on startup, add delete model
This commit is contained in:
parent
d34dc9d729
commit
6ef102cf21
4 changed files with 240 additions and 24 deletions
|
|
@ -310,6 +310,28 @@ async def startup_event():
|
||||||
Load model and pipeline config on startup.
|
Load model and pipeline config on startup.
|
||||||
"""
|
"""
|
||||||
logger.info("Starting inference service...")
|
logger.info("Starting inference service...")
|
||||||
|
|
||||||
|
# Mark any stale "running" records as failed — they belong to a previous
|
||||||
|
# process and will never complete.
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
stmt = (
|
||||||
|
sa_update(TrainingRun)
|
||||||
|
.where(TrainingRun.status == "running")
|
||||||
|
.values(
|
||||||
|
status="failed",
|
||||||
|
completed_at=datetime.utcnow(),
|
||||||
|
metrics_summary={"error": "Service restarted while training was in progress"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = db.execute(stmt)
|
||||||
|
db.commit()
|
||||||
|
if result.rowcount:
|
||||||
|
logger.warning(
|
||||||
|
f"Marked {result.rowcount} stale training run(s) as failed on startup"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to clean up stale training runs: {exc}")
|
||||||
|
|
||||||
# Load pipeline config
|
# Load pipeline config
|
||||||
config_path = Path("config/pipeline.yaml")
|
config_path = Path("config/pipeline.yaml")
|
||||||
|
|
@ -1115,6 +1137,81 @@ async def training_runs():
|
||||||
return TrainingRunsResponse(runs=runs)
|
return TrainingRunsResponse(runs=runs)
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveTrainingResponse(BaseModel):
    """Response model for GET /training/active."""

    # True when the service currently has an active training run id.
    active: bool
    # Id of the active run; left as None when nothing is active.
    run_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/training/active", response_model=ActiveTrainingResponse)
async def training_active():
    """
    Report which training run, if any, the service considers active.

    The run id is read from ``state.active_training_run_id`` while holding
    ``state.training_lock``. ``active`` is true exactly when that id is not
    None, and ``run_id`` echoes the id that was read.
    """
    with state.training_lock:
        current_run_id = state.active_training_run_id

    is_active = current_run_id is not None
    return ActiveTrainingResponse(active=is_active, run_id=current_run_id)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteRunResponse(BaseModel):
    """Response model for DELETE /training/runs/{run_id}."""

    # Id of the run the delete request targeted.
    run_id: str
    # Whether the run record was deleted.
    deleted: bool
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse)
async def delete_training_run(run_id: str):
    """
    Delete a training run record and its model artifact.

    The database row is removed first; the ``models/<run_id>.pkl`` artifact
    is then removed on a best-effort basis (a failure there is logged but
    does not fail the request, because the record is already gone).

    Returns HTTP 409 if the run is currently active.
    Returns HTTP 404 if the run_id doesn't exist.
    Returns HTTP 500 if the database delete fails.
    """
    from sqlalchemy import select, delete as sa_delete

    # Reject deletion of the active run
    with state.training_lock:
        if state.active_training_run_id == run_id:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Cannot delete an active training run",
            )

    try:
        with get_db() as db:
            stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
            row = db.execute(stmt).scalar_one_or_none()

            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Training run not found: {run_id}",
                )

            db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
            db.commit()
    except HTTPException:
        # Let the 404 above propagate untouched.
        raise
    except Exception as exc:
        logger.error(f"Failed to delete training run {run_id}: {exc}")
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete training run: {exc}",
        ) from exc

    # Remove model artifact if it exists. Attempt the unlink directly rather
    # than checking exists() first: the file may vanish between the check and
    # the unlink (e.g. a concurrent delete), which is not worth a warning.
    model_path = Path("models") / f"{run_id}.pkl"
    try:
        model_path.unlink()
    except FileNotFoundError:
        # No artifact was stored for this run (or it is already gone).
        pass
    except Exception as exc:
        # Deliberately broad: artifact cleanup is best-effort and must not
        # turn an already-committed record deletion into an error response.
        logger.warning(f"Could not delete model artifact {model_path}: {exc}")
    else:
        logger.info(f"Deleted model artifact: {model_path}")

    logger.info(f"Deleted training run: {run_id}")
    return DeleteRunResponse(run_id=run_id, deleted=True)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
|
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
|
||||||
async def training_dataset_info():
|
async def training_dataset_info():
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
34
src/app/api/training/active/route.ts
Normal file
34
src/app/api/training/active/route.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
import { NextRequest, NextResponse } from 'next/server';
|
||||||
|
|
||||||
|
const INFERENCE_API_URL = process.env.INFERENCE_API_URL || 'http://localhost:8001';
|
||||||
|
const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '10000', 10);
|
||||||
|
|
||||||
|
/**
 * Proxy GET /api/training/active to the inference service's /training/active.
 *
 * - Upstream non-2xx: responds with `{ error }` and the upstream status.
 * - Timeout (INFERENCE_API_TIMEOUT ms, covering the body read too): 504.
 * - Inference service unreachable: responds `{ active: false, run_id: null }`
 *   so the UI does not stay stuck in a "training" state.
 * - Upstream 2xx with an unparseable body: 502.
 */
export async function GET(_request: NextRequest) {
  const controller = new AbortController();
  const timeoutId = setTimeout(() => controller.abort(), INFERENCE_API_TIMEOUT);

  try {
    const response = await fetch(`${INFERENCE_API_URL}/training/active`, {
      method: 'GET',
      headers: { 'Content-Type': 'application/json' },
      signal: controller.signal,
    });

    // Parse defensively: an upstream error page may not be JSON. The timer is
    // still armed here, so a stalled body is aborted as well.
    let data: any = null;
    try {
      data = await response.json();
    } catch (parseError) {
      // An abort during the body read must surface as a timeout, not as a parse failure.
      if (controller.signal.aborted) throw parseError;
    }

    if (!response.ok) {
      return NextResponse.json(
        { error: data?.detail || 'Failed to fetch active run' },
        { status: response.status }
      );
    }
    if (data === null) {
      return NextResponse.json({ error: 'Invalid response from inference service' }, { status: 502 });
    }
    return NextResponse.json(data);
  } catch (error: any) {
    if (error.name === 'AbortError') {
      return NextResponse.json({ error: 'Request timed out' }, { status: 504 });
    }
    if (error.cause?.code === 'ECONNREFUSED' || error.message?.includes('fetch failed')) {
      return NextResponse.json({ active: false, run_id: null });
    }
    console.error('training/active proxy error:', error);
    return NextResponse.json({ error: 'Internal server error' }, { status: 500 });
  } finally {
    // Single cleanup point for every exit path.
    clearTimeout(timeoutId);
  }
}
|
||||||
37
src/app/api/training/runs/[run_id]/route.ts
Normal file
37
src/app/api/training/runs/[run_id]/route.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
import { NextRequest, NextResponse } from 'next/server';
|
||||||
|
|
||||||
|
const INFERENCE_API_URL = process.env.INFERENCE_API_URL || 'http://localhost:8001';
|
||||||
|
const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '10000', 10);
|
||||||
|
|
||||||
|
/**
 * Proxy DELETE /api/training/runs/[run_id] to the inference service.
 *
 * - Upstream non-2xx (e.g. 404, 409): responds with `{ error }` and the upstream status.
 * - Timeout (INFERENCE_API_TIMEOUT ms, covering the body read too): 504.
 * - Inference service unreachable: 503.
 * - Upstream 2xx with an unparseable body: 502.
 */
export async function DELETE(
  _request: NextRequest,
  { params }: { params: { run_id: string } }
) {
  const { run_id } = params;
  const controller = new AbortController();
  const timeoutId = setTimeout(() => controller.abort(), INFERENCE_API_TIMEOUT);

  try {
    // Encode the id so reserved characters ("/", "?", "#", "..") cannot change
    // which upstream path the request is sent to.
    const upstreamUrl = `${INFERENCE_API_URL}/training/runs/${encodeURIComponent(run_id)}`;
    const response = await fetch(upstreamUrl, {
      method: 'DELETE',
      signal: controller.signal,
    });

    // Parse defensively: an upstream error page may not be JSON. The timer is
    // still armed here, so a stalled body is aborted as well.
    let data: any = null;
    try {
      data = await response.json();
    } catch (parseError) {
      // An abort during the body read must surface as a timeout, not as a parse failure.
      if (controller.signal.aborted) throw parseError;
    }

    if (!response.ok) {
      return NextResponse.json(
        { error: data?.detail || 'Failed to delete run' },
        { status: response.status }
      );
    }
    if (data === null) {
      return NextResponse.json({ error: 'Invalid response from inference service' }, { status: 502 });
    }
    return NextResponse.json(data);
  } catch (error: any) {
    if (error.name === 'AbortError') {
      return NextResponse.json({ error: 'Request timed out' }, { status: 504 });
    }
    if (error.cause?.code === 'ECONNREFUSED' || error.message?.includes('fetch failed')) {
      return NextResponse.json({ error: 'Inference service unavailable' }, { status: 503 });
    }
    console.error('training/runs DELETE proxy error:', error);
    return NextResponse.json({ error: 'Internal server error' }, { status: 500 });
  } finally {
    // Single cleanup point for every exit path.
    clearTimeout(timeoutId);
  }
}
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { useState, useEffect, useRef, useCallback } from 'react';
|
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||||
import { ChevronDown } from 'lucide-react';
|
import { ChevronDown, Trash2 } from 'lucide-react';
|
||||||
|
|
||||||
interface DatasetInfo {
|
interface DatasetInfo {
|
||||||
path: string;
|
path: string;
|
||||||
|
|
@ -68,6 +68,7 @@ export default function TrainingPanel() {
|
||||||
const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
|
const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
|
||||||
const [isLoadingDataset, setIsLoadingDataset] = useState(false);
|
const [isLoadingDataset, setIsLoadingDataset] = useState(false);
|
||||||
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
|
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
|
||||||
|
const [deletingRunIds, setDeletingRunIds] = useState<Set<string>>(new Set());
|
||||||
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
|
|
||||||
const fetchDatasetInfo = useCallback(async () => {
|
const fetchDatasetInfo = useCallback(async () => {
|
||||||
|
|
@ -95,39 +96,53 @@ export default function TrainingPanel() {
|
||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Load data when panel expands
|
// Fetch the authoritative active run from the backend
|
||||||
useEffect(() => {
|
const fetchActiveRun = useCallback(async (): Promise<string | null> => {
|
||||||
if (expanded) {
|
try {
|
||||||
fetchDatasetInfo();
|
const res = await fetch('/api/training/active');
|
||||||
setIsLoadingRuns(true);
|
if (!res.ok) return null;
|
||||||
fetchRuns().finally(() => setIsLoadingRuns(false));
|
const data = await res.json();
|
||||||
|
return data.active ? (data.run_id ?? null) : null;
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
}, [expanded, fetchDatasetInfo, fetchRuns]);
|
}, []);
|
||||||
|
|
||||||
// Check if there's an active training run on expand
|
// Load data when panel expands; check backend for truly active run
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (expanded && runs.length > 0) {
|
if (!expanded) return;
|
||||||
const running = runs.find((r) => r.status === 'running');
|
|
||||||
if (running && !activeRunId) {
|
fetchDatasetInfo();
|
||||||
setActiveRunId(running.run_id);
|
setIsLoadingRuns(true);
|
||||||
|
|
||||||
|
Promise.all([fetchRuns(), fetchActiveRun()]).then(([_runs, serverActiveRunId]) => {
|
||||||
|
if (serverActiveRunId) {
|
||||||
|
setActiveRunId(serverActiveRunId);
|
||||||
setIsTraining(true);
|
setIsTraining(true);
|
||||||
|
} else {
|
||||||
|
// Ensure we are not stuck in training state
|
||||||
|
setIsTraining(false);
|
||||||
|
setActiveRunId(null);
|
||||||
}
|
}
|
||||||
}
|
}).finally(() => setIsLoadingRuns(false));
|
||||||
}, [expanded, runs, activeRunId]);
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [expanded]);
|
||||||
|
|
||||||
// Poll while training is active
|
// Poll while training is active
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isTraining && activeRunId) {
|
if (isTraining && activeRunId) {
|
||||||
pollIntervalRef.current = setInterval(async () => {
|
pollIntervalRef.current = setInterval(async () => {
|
||||||
const updatedRuns = await fetchRuns();
|
const [updatedRuns, serverActiveRunId] = await Promise.all([fetchRuns(), fetchActiveRun()]);
|
||||||
const activeRun = updatedRuns.find((r) => r.run_id === activeRunId);
|
|
||||||
if (activeRun && activeRun.status !== 'running') {
|
// If the server says nothing is active, training is done
|
||||||
|
if (!serverActiveRunId) {
|
||||||
setIsTraining(false);
|
setIsTraining(false);
|
||||||
|
const finishedRun = updatedRuns.find((r) => r.run_id === activeRunId);
|
||||||
setActiveRunId(null);
|
setActiveRunId(null);
|
||||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||||
|
|
||||||
if (activeRun.status === 'completed') {
|
if (finishedRun?.status === 'completed') {
|
||||||
const metrics = activeRun.metrics_summary;
|
const metrics = finishedRun.metrics_summary;
|
||||||
const metricStr = metrics
|
const metricStr = metrics
|
||||||
? Object.entries(metrics)
|
? Object.entries(metrics)
|
||||||
.map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`)
|
.map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`)
|
||||||
|
|
@ -137,10 +152,10 @@ export default function TrainingPanel() {
|
||||||
type: 'success',
|
type: 'success',
|
||||||
text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
|
text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
|
||||||
});
|
});
|
||||||
} else {
|
} else if (finishedRun?.status === 'failed') {
|
||||||
setStatusMessage({
|
setStatusMessage({
|
||||||
type: 'error',
|
type: 'error',
|
||||||
text: activeRun.error || 'Training failed',
|
text: finishedRun.error || 'Training failed',
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -155,7 +170,7 @@ export default function TrainingPanel() {
|
||||||
return () => {
|
return () => {
|
||||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||||
};
|
};
|
||||||
}, [isTraining, activeRunId, fetchRuns]);
|
}, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);
|
||||||
|
|
||||||
const handleStartTraining = async () => {
|
const handleStartTraining = async () => {
|
||||||
setStatusMessage(null);
|
setStatusMessage(null);
|
||||||
|
|
@ -192,6 +207,27 @@ export default function TrainingPanel() {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Delete one training run through the API proxy, then refresh the run list.
// The id is held in `deletingRunIds` for the duration of the request (the
// delete button reads that set for its disabled state) and is removed again
// in `finally`, whether the request succeeded or failed.
const handleDeleteRun = async (runId: string) => {
  setDeletingRunIds((prev) => new Set(prev).add(runId));
  try {
    // Encode the id so reserved characters cannot alter the request path.
    const res = await fetch(`/api/training/runs/${encodeURIComponent(runId)}`, { method: 'DELETE' });
    if (!res.ok) {
      // The error body may be missing or not JSON; fall back to a generic message.
      const data = await res.json().catch(() => ({}));
      setStatusMessage({ type: 'error', text: data.error || data.detail || 'Failed to delete run' });
      return;
    }
    await fetchRuns();
  } catch {
    setStatusMessage({ type: 'error', text: 'Failed to delete run' });
  } finally {
    setDeletingRunIds((prev) => {
      const next = new Set(prev);
      next.delete(runId);
      return next;
    });
  }
};
|
||||||
|
|
||||||
const datasetMissing = datasetInfo !== null && !datasetInfo.exists;
|
const datasetMissing = datasetInfo !== null && !datasetInfo.exists;
|
||||||
const canTrain = !isTraining && !datasetMissing && datasetInfo !== null;
|
const canTrain = !isTraining && !datasetMissing && datasetInfo !== null;
|
||||||
|
|
||||||
|
|
@ -301,7 +337,19 @@ export default function TrainingPanel() {
|
||||||
<span className="text-foreground font-medium capitalize">
|
<span className="text-foreground font-medium capitalize">
|
||||||
{run.model_type.replace('_', ' ')}
|
{run.model_type.replace('_', ' ')}
|
||||||
</span>
|
</span>
|
||||||
<StatusBadge status={run.status} />
|
<div className="flex items-center gap-1">
|
||||||
|
<StatusBadge status={run.status} />
|
||||||
|
{run.status !== 'running' && (
|
||||||
|
<button
|
||||||
|
onClick={() => handleDeleteRun(run.run_id)}
|
||||||
|
disabled={deletingRunIds.has(run.run_id)}
|
||||||
|
title="Delete run"
|
||||||
|
className="text-muted-foreground hover:text-destructive transition-colors disabled:opacity-40"
|
||||||
|
>
|
||||||
|
<Trash2 className="w-2.5 h-2.5" />
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="text-muted-foreground">{formatDate(run.created_at)}</div>
|
<div className="text-muted-foreground">{formatDate(run.created_at)}</div>
|
||||||
{run.status === 'completed' && run.metrics_summary && (
|
{run.status === 'completed' && run.metrics_summary && (
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue