security: validate run_id format and add path containment check in ML service
- Add `import re` to services/ml/app/main.py
- In POST /model/load: validate run_id matches ^[a-zA-Z0-9_-]+$ before DB lookup; use Path.resolve() + directory containment check before loading model artifact
- In DELETE /training/runs/{run_id}: validate run_id matches ^[a-zA-Z0-9_-]+$ before any processing; use Path.resolve() + directory containment check before deleting model artifact
- Both endpoints return HTTP 400 with {"detail": "Invalid run_id format"} on invalid input
- Mark task 2.2 as completed in openspec/changes/code-review-fix/tasks.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
870f92d208
commit
67dd7aa2f0
2 changed files with 31 additions and 3 deletions
|
|
@ -5,6 +5,7 @@ Provides REST API endpoints for model serving, health checks, and prediction.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid as uuid_lib
|
||||
from pathlib import Path
|
||||
|
|
@ -1244,11 +1245,19 @@ async def delete_training_run(run_id: str):
|
|||
"""
|
||||
Delete a training run record and its model artifact.
|
||||
|
||||
Returns HTTP 400 if the run_id format is invalid.
|
||||
Returns HTTP 409 if the run is currently active.
|
||||
Returns HTTP 404 if the run_id doesn't exist.
|
||||
"""
|
||||
from sqlalchemy import select, delete as sa_delete
|
||||
|
||||
# Validate run_id format to prevent path traversal
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', run_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid run_id format",
|
||||
)
|
||||
|
||||
# Reject deletion of the active run
|
||||
with state.training_lock:
|
||||
if state.active_training_run_id == run_id:
|
||||
|
|
@ -1280,7 +1289,13 @@ async def delete_training_run(run_id: str):
|
|||
)
|
||||
|
||||
# Remove model artifact if it exists
|
||||
model_path = Path("models") / f"{run_id}.pkl"
|
||||
models_base = Path("models").resolve()
|
||||
model_path = (Path("models") / f"{run_id}.pkl").resolve()
|
||||
if not str(model_path).startswith(str(models_base) + "/"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid run_id format",
|
||||
)
|
||||
if model_path.exists():
|
||||
try:
|
||||
model_path.unlink()
|
||||
|
|
@ -1396,6 +1411,13 @@ async def model_load(request: ModelLoadRequest):
|
|||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# 0. Validate run_id format to prevent path traversal
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', request.run_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid run_id format",
|
||||
)
|
||||
|
||||
# 1. Look up the training run
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
|
@ -1421,7 +1443,13 @@ async def model_load(request: ModelLoadRequest):
|
|||
)
|
||||
|
||||
# 2. Resolve model artifact path
|
||||
model_path = Path("models") / f"{request.run_id}.pkl"
|
||||
models_base = Path("models").resolve()
|
||||
model_path = (Path("models") / f"{request.run_id}.pkl").resolve()
|
||||
if not str(model_path).startswith(str(models_base) + "/"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid run_id format",
|
||||
)
|
||||
if not model_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue