diff --git a/services/ml/app/main.py b/services/ml/app/main.py index 74aa572..993361b 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -310,6 +310,28 @@ async def startup_event(): Load model and pipeline config on startup. """ logger.info("Starting inference service...") + + # Mark any stale "running" records as failed — they belong to a previous + # process and will never complete. + try: + with get_db() as db: + stmt = ( + sa_update(TrainingRun) + .where(TrainingRun.status == "running") + .values( + status="failed", + completed_at=datetime.utcnow(), + metrics_summary={"error": "Service restarted while training was in progress"}, + ) + ) + result = db.execute(stmt) + db.commit() + if result.rowcount: + logger.warning( + f"Marked {result.rowcount} stale training run(s) as failed on startup" + ) + except Exception as exc: + logger.error(f"Failed to clean up stale training runs: {exc}") # Load pipeline config config_path = Path("config/pipeline.yaml") @@ -1115,6 +1137,81 @@ async def training_runs(): return TrainingRunsResponse(runs=runs) +class ActiveTrainingResponse(BaseModel): + """Response model for GET /training/active.""" + active: bool + run_id: Optional[str] = None + + +@app.get("/training/active", response_model=ActiveTrainingResponse) +async def training_active(): + """ + Return whether a training run is currently active and its run_id. + """ + with state.training_lock: + run_id = state.active_training_run_id + return ActiveTrainingResponse(active=run_id is not None, run_id=run_id) + + +class DeleteRunResponse(BaseModel): + """Response model for DELETE /training/runs/{run_id}.""" + run_id: str + deleted: bool + + +@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse) +async def delete_training_run(run_id: str): + """ + Delete a training run record and its model artifact. + + Returns HTTP 409 if the run is currently active. + Returns HTTP 404 if the run_id doesn't exist. + """ + from sqlalchemy import select, delete as sa_delete + + # Reject deletion of the active run + with state.training_lock: + if state.active_training_run_id == run_id: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Cannot delete an active training run", + ) + + try: + with get_db() as db: + stmt = select(TrainingRun).where(TrainingRun.run_id == run_id) + row = db.execute(stmt).scalar_one_or_none() + + if row is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Training run not found: {run_id}", + ) + + db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id)) + db.commit() + except HTTPException: + raise + except Exception as exc: + logger.error(f"Failed to delete training run {run_id}: {exc}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete training run: {exc}", + ) + + # Remove model artifact if it exists + model_path = Path("models") / f"{run_id}.pkl" + if model_path.exists(): + try: + model_path.unlink() + logger.info(f"Deleted model artifact: {model_path}") + except Exception as exc: + logger.warning(f"Could not delete model artifact {model_path}: {exc}") + + logger.info(f"Deleted training run: {run_id}") + return DeleteRunResponse(run_id=run_id, deleted=True) + + @app.get("/training/dataset-info", response_model=DatasetInfoResponse) async def training_dataset_info(): """ diff --git a/src/app/api/training/active/route.ts b/src/app/api/training/active/route.ts new file mode 100644 index 0000000..b62aa36 --- /dev/null +++ b/src/app/api/training/active/route.ts @@ -0,0 +1,34 @@ +import { NextRequest, NextResponse } from 'next/server'; + +const INFERENCE_API_URL = process.env.INFERENCE_API_URL || 'http://localhost:8001'; +const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '10000', 10); + +export async function GET(_request: NextRequest) { + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), INFERENCE_API_TIMEOUT); + + try { + const response = await fetch(`${INFERENCE_API_URL}/training/active`, { + method: 'GET', + headers: { 'Content-Type': 'application/json' }, + signal: controller.signal, + }); + clearTimeout(timeoutId); + + const data = await response.json(); + if (!response.ok) { + return NextResponse.json({ error: data.detail || 'Failed to fetch active run' }, { status: response.status }); + } + return NextResponse.json(data); + } catch (error: any) { + clearTimeout(timeoutId); + if (error.name === 'AbortError') { + return NextResponse.json({ error: 'Request timed out' }, { status: 504 }); + } + if (error.cause?.code === 'ECONNREFUSED' || error.message?.includes('fetch failed')) { + return NextResponse.json({ active: false, run_id: null }); + } + console.error('training/active proxy error:', error); + return NextResponse.json({ error: 'Internal server error' }, { status: 500 }); + } +} diff --git a/src/app/api/training/runs/[run_id]/route.ts b/src/app/api/training/runs/[run_id]/route.ts new file mode 100644 index 0000000..bce03d8 --- /dev/null +++ b/src/app/api/training/runs/[run_id]/route.ts @@ -0,0 +1,37 @@ +import { NextRequest, NextResponse } from 'next/server'; + +const INFERENCE_API_URL = process.env.INFERENCE_API_URL || 'http://localhost:8001'; +const INFERENCE_API_TIMEOUT = parseInt(process.env.INFERENCE_API_TIMEOUT || '10000', 10); + +export async function DELETE( + _request: NextRequest, + { params }: { params: { run_id: string } } +) { + const { run_id } = params; + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), INFERENCE_API_TIMEOUT); + + try { + const response = await fetch(`${INFERENCE_API_URL}/training/runs/${run_id}`, { + method: 'DELETE', + signal: controller.signal, + }); + clearTimeout(timeoutId); + + const data = await response.json(); + if (!response.ok) { + return NextResponse.json({ error: data.detail || 'Failed to delete run' }, { status: response.status }); + } + return NextResponse.json(data); + } catch (error: any) { + clearTimeout(timeoutId); + if (error.name === 'AbortError') { + return NextResponse.json({ error: 'Request timed out' }, { status: 504 }); + } + if (error.cause?.code === 'ECONNREFUSED' || error.message?.includes('fetch failed')) { + return NextResponse.json({ error: 'Inference service unavailable' }, { status: 503 }); + } + console.error('training/runs DELETE proxy error:', error); + return NextResponse.json({ error: 'Internal server error' }, { status: 500 }); + } +} diff --git a/src/components/TrainingPanel.tsx b/src/components/TrainingPanel.tsx index c1efd31..0afeac8 100644 --- a/src/components/TrainingPanel.tsx +++ b/src/components/TrainingPanel.tsx @@ -1,7 +1,7 @@ 'use client'; import { useState, useEffect, useRef, useCallback } from 'react'; -import { ChevronDown } from 'lucide-react'; +import { ChevronDown, Trash2 } from 'lucide-react'; interface DatasetInfo { path: string; @@ -68,6 +68,7 @@ export default function TrainingPanel() { const [statusMessage, setStatusMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null); const [isLoadingDataset, setIsLoadingDataset] = useState(false); const [isLoadingRuns, setIsLoadingRuns] = useState(false); + const [deletingRunIds, setDeletingRunIds] = useState>(new Set()); const pollIntervalRef = useRef(null); const fetchDatasetInfo = useCallback(async () => { @@ -95,39 +96,53 @@ export default function TrainingPanel() { } }, []); - // Load data when panel expands - useEffect(() => { - if (expanded) { - fetchDatasetInfo(); - setIsLoadingRuns(true); - fetchRuns().finally(() => setIsLoadingRuns(false)); + // Fetch the authoritative active run from the backend + const fetchActiveRun = useCallback(async (): Promise => { + try { + const res = await fetch('/api/training/active'); + if (!res.ok) return null; + const data = await res.json(); + return data.active ? (data.run_id ?? null) : null; + } catch { + return null; } - }, [expanded, fetchDatasetInfo, fetchRuns]); + }, []); - // Check if there's an active training run on expand + // Load data when panel expands; check backend for truly active run useEffect(() => { - if (expanded && runs.length > 0) { - const running = runs.find((r) => r.status === 'running'); - if (running && !activeRunId) { - setActiveRunId(running.run_id); + if (!expanded) return; + + fetchDatasetInfo(); + setIsLoadingRuns(true); + + Promise.all([fetchRuns(), fetchActiveRun()]).then(([_runs, serverActiveRunId]) => { + if (serverActiveRunId) { + setActiveRunId(serverActiveRunId); setIsTraining(true); + } else { + // Ensure we are not stuck in training state + setIsTraining(false); + setActiveRunId(null); } - } - }, [expanded, runs, activeRunId]); + }).finally(() => setIsLoadingRuns(false)); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [expanded]); // Poll while training is active useEffect(() => { if (isTraining && activeRunId) { pollIntervalRef.current = setInterval(async () => { - const updatedRuns = await fetchRuns(); - const activeRun = updatedRuns.find((r) => r.run_id === activeRunId); - if (activeRun && activeRun.status !== 'running') { + const [updatedRuns, serverActiveRunId] = await Promise.all([fetchRuns(), fetchActiveRun()]); + + // If the server says nothing is active, training is done + if (!serverActiveRunId) { setIsTraining(false); + const finishedRun = updatedRuns.find((r) => r.run_id === activeRunId); setActiveRunId(null); if (pollIntervalRef.current) clearInterval(pollIntervalRef.current); - if (activeRun.status === 'completed') { - const metrics = activeRun.metrics_summary; + if (finishedRun?.status === 'completed') { + const metrics = finishedRun.metrics_summary; const metricStr = metrics ? Object.entries(metrics) .map(([k, v]) => `${k}: ${typeof v === 'number' ? (v * 100).toFixed(1) + '%' : v}`) @@ -137,10 +152,10 @@ export default function TrainingPanel() { type: 'success', text: `Training complete!${metricStr ? ' ' + metricStr : ''}`, }); - } else { + } else if (finishedRun?.status === 'failed') { setStatusMessage({ type: 'error', - text: activeRun.error || 'Training failed', + text: finishedRun.error || 'Training failed', }); } } @@ -155,7 +170,7 @@ export default function TrainingPanel() { return () => { if (pollIntervalRef.current) clearInterval(pollIntervalRef.current); }; - }, [isTraining, activeRunId, fetchRuns]); + }, [isTraining, activeRunId, fetchRuns, fetchActiveRun]); const handleStartTraining = async () => { setStatusMessage(null); @@ -192,6 +207,27 @@ export default function TrainingPanel() { } }; + const handleDeleteRun = async (runId: string) => { + setDeletingRunIds((prev) => new Set(prev).add(runId)); + try { + const res = await fetch(`/api/training/runs/${runId}`, { method: 'DELETE' }); + if (!res.ok) { + const data = await res.json().catch(() => ({})); + setStatusMessage({ type: 'error', text: data.error || data.detail || 'Failed to delete run' }); + return; + } + await fetchRuns(); + } catch { + setStatusMessage({ type: 'error', text: 'Failed to delete run' }); + } finally { + setDeletingRunIds((prev) => { + const next = new Set(prev); + next.delete(runId); + return next; + }); + } + }; + const datasetMissing = datasetInfo !== null && !datasetInfo.exists; const canTrain = !isTraining && !datasetMissing && datasetInfo !== null; @@ -301,7 +337,19 @@ export default function TrainingPanel() { {run.model_type.replace('_', ' ')} - +
+ + {run.status !== 'running' && ( + + )} +
{formatDate(run.created_at)}
{run.status === 'completed' && run.metrics_summary && (