fix: training panel stuck button, stale runs on startup, add delete model
This commit is contained in:
parent
d34dc9d729
commit
6ef102cf21
4 changed files with 240 additions and 24 deletions
|
|
@ -310,6 +310,28 @@ async def startup_event():
|
|||
Load model and pipeline config on startup.
|
||||
"""
|
||||
logger.info("Starting inference service...")
|
||||
|
||||
# Mark any stale "running" records as failed — they belong to a previous
|
||||
# process and will never complete.
|
||||
try:
|
||||
with get_db() as db:
|
||||
stmt = (
|
||||
sa_update(TrainingRun)
|
||||
.where(TrainingRun.status == "running")
|
||||
.values(
|
||||
status="failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
metrics_summary={"error": "Service restarted while training was in progress"},
|
||||
)
|
||||
)
|
||||
result = db.execute(stmt)
|
||||
db.commit()
|
||||
if result.rowcount:
|
||||
logger.warning(
|
||||
f"Marked {result.rowcount} stale training run(s) as failed on startup"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to clean up stale training runs: {exc}")
|
||||
|
||||
# Load pipeline config
|
||||
config_path = Path("config/pipeline.yaml")
|
||||
|
|
@ -1115,6 +1137,81 @@ async def training_runs():
|
|||
return TrainingRunsResponse(runs=runs)
|
||||
|
||||
|
||||
class ActiveTrainingResponse(BaseModel):
|
||||
"""Response model for GET /training/active."""
|
||||
active: bool
|
||||
run_id: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/training/active", response_model=ActiveTrainingResponse)
|
||||
async def training_active():
|
||||
"""
|
||||
Return whether a training run is currently active and its run_id.
|
||||
"""
|
||||
with state.training_lock:
|
||||
run_id = state.active_training_run_id
|
||||
return ActiveTrainingResponse(active=run_id is not None, run_id=run_id)
|
||||
|
||||
|
||||
class DeleteRunResponse(BaseModel):
|
||||
"""Response model for DELETE /training/runs/{run_id}."""
|
||||
run_id: str
|
||||
deleted: bool
|
||||
|
||||
|
||||
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse)
|
||||
async def delete_training_run(run_id: str):
|
||||
"""
|
||||
Delete a training run record and its model artifact.
|
||||
|
||||
Returns HTTP 409 if the run is currently active.
|
||||
Returns HTTP 404 if the run_id doesn't exist.
|
||||
"""
|
||||
from sqlalchemy import select, delete as sa_delete
|
||||
|
||||
# Reject deletion of the active run
|
||||
with state.training_lock:
|
||||
if state.active_training_run_id == run_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Cannot delete an active training run",
|
||||
)
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
|
||||
row = db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Training run not found: {run_id}",
|
||||
)
|
||||
|
||||
db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
|
||||
db.commit()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to delete training run {run_id}: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete training run: {exc}",
|
||||
)
|
||||
|
||||
# Remove model artifact if it exists
|
||||
model_path = Path("models") / f"{run_id}.pkl"
|
||||
if model_path.exists():
|
||||
try:
|
||||
model_path.unlink()
|
||||
logger.info(f"Deleted model artifact: {model_path}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Could not delete model artifact {model_path}: {exc}")
|
||||
|
||||
logger.info(f"Deleted training run: {run_id}")
|
||||
return DeleteRunResponse(run_id=run_id, deleted=True)
|
||||
|
||||
|
||||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse)
|
||||
async def training_dataset_info():
|
||||
"""
|
||||
|
|
|
|||
34
src/app/api/training/active/route.ts
Normal file
34
src/app/api/training/active/route.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import { NextRequest, NextResponse } from 'next/server';
|
||||
|
||||
const INFERENCE_API_URL = process.env.INFERENCE_API_URL || 'http://localhost:8001';
|
||||
const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '10000', 10);
|
||||
|
||||
export async function GET(_request: NextRequest) {
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), INFERENCE_API_TIMEOUT);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${INFERENCE_API_URL}/training/active`, {
|
||||
method: 'GET',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
signal: controller.signal,
|
||||
});
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
const data = await response.json();
|
||||
if (!response.ok) {
|
||||
return NextResponse.json({ error: data.detail || 'Failed to fetch active run' }, { status: response.status });
|
||||
}
|
||||
return NextResponse.json(data);
|
||||
} catch (error: any) {
|
||||
clearTimeout(timeoutId);
|
||||
if (error.name === 'AbortError') {
|
||||
return NextResponse.json({ error: 'Request timed out' }, { status: 504 });
|
||||
}
|
||||
if (error.cause?.code === 'ECONNREFUSED' || error.message?.includes('fetch failed')) {
|
||||
return NextResponse.json({ active: false, run_id: null });
|
||||
}
|
||||
console.error('training/active proxy error:', error);
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
37
src/app/api/training/runs/[run_id]/route.ts
Normal file
37
src/app/api/training/runs/[run_id]/route.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import { NextRequest, NextResponse } from 'next/server';
|
||||
|
||||
const INFERENCE_API_URL = process.env.INFERENCE_API_URL || 'http://localhost:8001';
|
||||
const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '10000', 10);
|
||||
|
||||
export async function DELETE(
|
||||
_request: NextRequest,
|
||||
{ params }: { params: { run_id: string } }
|
||||
) {
|
||||
const { run_id } = params;
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), INFERENCE_API_TIMEOUT);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${INFERENCE_API_URL}/training/runs/${run_id}`, {
|
||||
method: 'DELETE',
|
||||
signal: controller.signal,
|
||||
});
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
const data = await response.json();
|
||||
if (!response.ok) {
|
||||
return NextResponse.json({ error: data.detail || 'Failed to delete run' }, { status: response.status });
|
||||
}
|
||||
return NextResponse.json(data);
|
||||
} catch (error: any) {
|
||||
clearTimeout(timeoutId);
|
||||
if (error.name === 'AbortError') {
|
||||
return NextResponse.json({ error: 'Request timed out' }, { status: 504 });
|
||||
}
|
||||
if (error.cause?.code === 'ECONNREFUSED' || error.message?.includes('fetch failed')) {
|
||||
return NextResponse.json({ error: 'Inference service unavailable' }, { status: 503 });
|
||||
}
|
||||
console.error('training/runs DELETE proxy error:', error);
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
'use client';
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
import { ChevronDown, Trash2 } from 'lucide-react';
|
||||
|
||||
interface DatasetInfo {
|
||||
path: string;
|
||||
|
|
@ -68,6 +68,7 @@ export default function TrainingPanel() {
|
|||
const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null);
|
||||
const [isLoadingDataset, setIsLoadingDataset] = useState(false);
|
||||
const [isLoadingRuns, setIsLoadingRuns] = useState(false);
|
||||
const [deletingRunIds, setDeletingRunIds] = useState<Set<string>>(new Set());
|
||||
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const fetchDatasetInfo = useCallback(async () => {
|
||||
|
|
@ -95,39 +96,53 @@ export default function TrainingPanel() {
|
|||
}
|
||||
}, []);
|
||||
|
||||
// Load data when panel expands
|
||||
useEffect(() => {
|
||||
if (expanded) {
|
||||
fetchDatasetInfo();
|
||||
setIsLoadingRuns(true);
|
||||
fetchRuns().finally(() => setIsLoadingRuns(false));
|
||||
// Fetch the authoritative active run from the backend
|
||||
const fetchActiveRun = useCallback(async (): Promise<string | null> => {
|
||||
try {
|
||||
const res = await fetch('/api/training/active');
|
||||
if (!res.ok) return null;
|
||||
const data = await res.json();
|
||||
return data.active ? (data.run_id ?? null) : null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}, [expanded, fetchDatasetInfo, fetchRuns]);
|
||||
}, []);
|
||||
|
||||
// Check if there's an active training run on expand
|
||||
// Load data when panel expands; check backend for truly active run
|
||||
useEffect(() => {
|
||||
if (expanded && runs.length > 0) {
|
||||
const running = runs.find((r) => r.status === 'running');
|
||||
if (running && !activeRunId) {
|
||||
setActiveRunId(running.run_id);
|
||||
if (!expanded) return;
|
||||
|
||||
fetchDatasetInfo();
|
||||
setIsLoadingRuns(true);
|
||||
|
||||
Promise.all([fetchRuns(), fetchActiveRun()]).then(([_runs, serverActiveRunId]) => {
|
||||
if (serverActiveRunId) {
|
||||
setActiveRunId(serverActiveRunId);
|
||||
setIsTraining(true);
|
||||
} else {
|
||||
// Ensure we are not stuck in training state
|
||||
setIsTraining(false);
|
||||
setActiveRunId(null);
|
||||
}
|
||||
}
|
||||
}, [expanded, runs, activeRunId]);
|
||||
}).finally(() => setIsLoadingRuns(false));
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [expanded]);
|
||||
|
||||
// Poll while training is active
|
||||
useEffect(() => {
|
||||
if (isTraining && activeRunId) {
|
||||
pollIntervalRef.current = setInterval(async () => {
|
||||
const updatedRuns = await fetchRuns();
|
||||
const activeRun = updatedRuns.find((r) => r.run_id === activeRunId);
|
||||
if (activeRun && activeRun.status !== 'running') {
|
||||
const [updatedRuns, serverActiveRunId] = await Promise.all([fetchRuns(), fetchActiveRun()]);
|
||||
|
||||
// If the server says nothing is active, training is done
|
||||
if (!serverActiveRunId) {
|
||||
setIsTraining(false);
|
||||
const finishedRun = updatedRuns.find((r) => r.run_id === activeRunId);
|
||||
setActiveRunId(null);
|
||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||
|
||||
if (activeRun.status === 'completed') {
|
||||
const metrics = activeRun.metrics_summary;
|
||||
if (finishedRun?.status === 'completed') {
|
||||
const metrics = finishedRun.metrics_summary;
|
||||
const metricStr = metrics
|
||||
? Object.entries(metrics)
|
||||
.map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`)
|
||||
|
|
@ -137,10 +152,10 @@ export default function TrainingPanel() {
|
|||
type: 'success',
|
||||
text: `Training complete!${metricStr ? ' ' + metricStr : ''}`,
|
||||
});
|
||||
} else {
|
||||
} else if (finishedRun?.status === 'failed') {
|
||||
setStatusMessage({
|
||||
type: 'error',
|
||||
text: activeRun.error || 'Training failed',
|
||||
text: finishedRun.error || 'Training failed',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -155,7 +170,7 @@ export default function TrainingPanel() {
|
|||
return () => {
|
||||
if (pollIntervalRef.current) clearInterval(pollIntervalRef.current);
|
||||
};
|
||||
}, [isTraining, activeRunId, fetchRuns]);
|
||||
}, [isTraining, activeRunId, fetchRuns, fetchActiveRun]);
|
||||
|
||||
const handleStartTraining = async () => {
|
||||
setStatusMessage(null);
|
||||
|
|
@ -192,6 +207,27 @@ export default function TrainingPanel() {
|
|||
}
|
||||
};
|
||||
|
||||
const handleDeleteRun = async (runId: string) => {
|
||||
setDeletingRunIds((prev) => new Set(prev).add(runId));
|
||||
try {
|
||||
const res = await fetch(`/api/training/runs/${runId}`, { method: 'DELETE' });
|
||||
if (!res.ok) {
|
||||
const data = await res.json().catch(() => ({}));
|
||||
setStatusMessage({ type: 'error', text: data.error || data.detail || 'Failed to delete run' });
|
||||
return;
|
||||
}
|
||||
await fetchRuns();
|
||||
} catch {
|
||||
setStatusMessage({ type: 'error', text: 'Failed to delete run' });
|
||||
} finally {
|
||||
setDeletingRunIds((prev) => {
|
||||
const next = new Set(prev);
|
||||
next.delete(runId);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const datasetMissing = datasetInfo !== null && !datasetInfo.exists;
|
||||
const canTrain = !isTraining && !datasetMissing && datasetInfo !== null;
|
||||
|
||||
|
|
@ -301,7 +337,19 @@ export default function TrainingPanel() {
|
|||
<span className="text-foreground font-medium capitalize">
|
||||
{run.model_type.replace('_', ' ')}
|
||||
</span>
|
||||
<StatusBadge status={run.status} />
|
||||
<div className="flex items-center gap-1">
|
||||
<StatusBadge status={run.status} />
|
||||
{run.status !== 'running' && (
|
||||
<button
|
||||
onClick={() => handleDeleteRun(run.run_id)}
|
||||
disabled={deletingRunIds.has(run.run_id)}
|
||||
title="Delete run"
|
||||
className="text-muted-foreground hover:text-destructive transition-colors disabled:opacity-40"
|
||||
>
|
||||
<Trash2 className="w-2.5 h-2.5" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-muted-foreground">{formatDate(run.created_at)}</div>
|
||||
{run.status === 'completed' && run.metrics_summary && (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue