feat: add run_id format validation to GET /training/runs/{run_id} endpoint

- Add new GET endpoint for retrieving a specific training run by run_id
- Validate run_id format with regex pattern ^[a-zA-Z0-9_-]+$ before DB access
- Return HTTP 400 for invalid run_id format, HTTP 404 for non-existent runs
- Ensure DELETE endpoint validation is correctly placed before any DB access
- Both endpoints now provide consistent security validation
- Mark task 5.8 as completed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:33:12 +01:00
parent 3dc0014328
commit 3d8672121e
2 changed files with 48 additions and 1 deletions

View file

@ -48,7 +48,7 @@
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
- [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
- [x] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
- [x] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
## 6. Infrastructure & Docker

View file

@ -1512,6 +1512,53 @@ async def delete_training_run(run_id: str):
return DeleteRunResponse(run_id=run_id, deleted=True)
@app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)])
async def get_training_run(run_id: str):
"""
Get information about a specific training run by run_id.
Returns HTTP 400 if the run_id format is invalid.
Returns HTTP 404 if the run_id doesn't exist.
"""
from sqlalchemy import select
# Validate run_id format to prevent path traversal
if not re.match(r'^[a-zA-Z0-9_-]+$', run_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid run_id format",
)
try:
with get_db() as db:
stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
row = db.execute(stmt).scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Training run not found: {run_id}",
)
return TrainingRunInfo(
run_id=row.run_id,
model_type=row.model_type,
status=row.status,
experiment_name=row.experiment_name,
created_at=row.created_at.isoformat() if row.created_at else None,
completed_at=row.completed_at.isoformat() if row.completed_at else None,
metrics_summary=row.metrics_summary,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Failed to fetch training run {run_id}: {exc}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error",
)
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
async def training_dataset_info():
"""