diff --git a/openspec/changes/code-review-fix/tasks.md b/openspec/changes/code-review-fix/tasks.md index 6da5d84..ca90601 100644 --- a/openspec/changes/code-review-fix/tasks.md +++ b/openspec/changes/code-review-fix/tasks.md @@ -12,7 +12,7 @@ ## 2. Security Critical — Input Validation & CORS - [x] 2.1 `[haiku]` Validate `run_id` matches `/^[a-zA-Z0-9_-]+$/` in `src/app/api/training/runs/[run_id]/route.ts` before interpolation -- [ ] 2.2 `[sonnet]` Validate `run_id` format and use `Path.resolve()` + directory containment check in `services/ml/app/main.py` (model load at line 1203, delete at line 1312) +- [x] 2.2 `[sonnet]` Validate `run_id` format and use `Path.resolve()` + directory containment check in `services/ml/app/main.py` (model load at line 1203, delete at line 1312) - [ ] 2.3 `[sonnet]` Add file size check (reject >10MB) and row count limit (500,000) to `src/app/api/upload/route.ts` - [ ] 2.4 `[haiku]` Add file type validation (`.csv` extension, text MIME type) to `src/app/api/upload/route.ts` - [ ] 2.5 `[haiku]` Fix CORS in `services/ml/app/main.py`: replace `allow_origins=["*"]` with `["http://localhost:3000"]` and support `CORS_ORIGINS` env var diff --git a/services/ml/app/main.py b/services/ml/app/main.py index 0f43cd8..d93955f 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -5,6 +5,7 @@ Provides REST API endpoints for model serving, health checks, and prediction. """ import logging +import re import threading import uuid as uuid_lib from pathlib import Path @@ -1244,11 +1245,19 @@ async def delete_training_run(run_id: str): """ Delete a training run record and its model artifact. + Returns HTTP 400 if the run_id format is invalid. Returns HTTP 409 if the run is currently active. Returns HTTP 404 if the run_id doesn't exist. """ from sqlalchemy import select, delete as sa_delete + # Validate run_id format to prevent path traversal + if not re.match(r'^[a-zA-Z0-9_-]+$', run_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid run_id format", + ) + # Reject deletion of the active run with state.training_lock: if state.active_training_run_id == run_id: @@ -1280,7 +1289,13 @@ async def delete_training_run(run_id: str): ) # Remove model artifact if it exists - model_path = Path("models") / f"{run_id}.pkl" + models_base = Path("models").resolve() + model_path = (Path("models") / f"{run_id}.pkl").resolve() + if not str(model_path).startswith(str(models_base) + "/"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid run_id format", + ) if model_path.exists(): try: model_path.unlink() @@ -1396,6 +1411,13 @@ async def model_load(request: ModelLoadRequest): """ from sqlalchemy import select + # 0. Validate run_id format to prevent path traversal + if not re.match(r'^[a-zA-Z0-9_-]+$', request.run_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid run_id format", + ) + # 1. Look up the training run try: with get_db() as db: @@ -1421,7 +1443,13 @@ async def model_load(request: ModelLoadRequest): ) # 2. Resolve model artifact path - model_path = Path("models") / f"{request.run_id}.pkl" + models_base = Path("models").resolve() + model_path = (Path("models") / f"{request.run_id}.pkl").resolve() + if not str(model_path).startswith(str(models_base) + "/"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid run_id format", + ) if not model_path.exists(): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND,