feat: add run_id format validation to GET /training/runs/{run_id} endpoint
- Add new GET endpoint for retrieving a specific training run by run_id - Validate run_id format with regex pattern ^[a-zA-Z0-9_-]+$ before DB access - Return HTTP 400 for invalid run_id format, HTTP 404 for non-existent runs - Ensure DELETE endpoint validation is correctly placed before any DB access - Both endpoints now provide consistent security validation - Mark task 5.8 as completed Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3dc0014328
commit
3d8672121e
2 changed files with 48 additions and 1 deletions
|
|
@ -48,7 +48,7 @@
|
|||
- [x] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
||||
- [x] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
||||
- [x] 5.7 `[sonnet]` Add training resource limits: 500MB dataset size check, 30-minute timeout with status update on expiry in `services/ml/app/main.py:907-1030`
|
||||
- [ ] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
|
||||
- [x] 5.8 `[haiku]` Add `run_id` format validation to `DELETE /training/runs/{run_id}` and `GET /training/runs/{run_id}` endpoints
|
||||
|
||||
## 6. Infrastructure & Docker
|
||||
|
||||
|
|
|
|||
|
|
@ -1512,6 +1512,53 @@ async def delete_training_run(run_id: str):
|
|||
return DeleteRunResponse(run_id=run_id, deleted=True)
|
||||
|
||||
|
||||
@app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)])
|
||||
async def get_training_run(run_id: str):
|
||||
"""
|
||||
Get information about a specific training run by run_id.
|
||||
|
||||
Returns HTTP 400 if the run_id format is invalid.
|
||||
Returns HTTP 404 if the run_id doesn't exist.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# Validate run_id format to prevent path traversal
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', run_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid run_id format",
|
||||
)
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
stmt = select(TrainingRun).where(TrainingRun.run_id == run_id)
|
||||
row = db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Training run not found: {run_id}",
|
||||
)
|
||||
|
||||
return TrainingRunInfo(
|
||||
run_id=row.run_id,
|
||||
model_type=row.model_type,
|
||||
status=row.status,
|
||||
experiment_name=row.experiment_name,
|
||||
created_at=row.created_at.isoformat() if row.created_at else None,
|
||||
completed_at=row.completed_at.isoformat() if row.completed_at else None,
|
||||
metrics_summary=row.metrics_summary,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch training run {run_id}: {exc}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/training/dataset-info", response_model=DatasetInfoResponse, dependencies=[Depends(verify_api_key)])
|
||||
async def training_dataset_info():
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue