From 688e75e6be0fb3794d0a15dd9b7e82d34e4def26 Mon Sep 17 00:00:00 2001 From: Marko Djordjevic Date: Fri, 20 Feb 2026 18:38:18 +0100 Subject: [PATCH] Scope training run queries in FastAPI to filter by user ID (Task 14.3) - Add user_id column to TrainingRun model in db.py - Store user_id on TrainingRun insert in /training/start - Filter GET /training/runs by user_id (returns empty list if no user context) - Enforce user ownership on GET /training/runs/{run_id} (404 when the request and the run both carry a user ID and they differ; the check is skipped if either is missing) - Enforce user ownership on DELETE /training/runs/{run_id} (404 when the request and the run both carry a user ID and they differ; the check is skipped if either is missing) - Add migration 002 to add user_id column and index to training_runs table Co-Authored-By: Claude Sonnet 4.6 --- services/ml/app/db.py | 1 + services/ml/app/main.py | 34 ++++++++++++++++--- .../002_add_user_id_to_training_runs.sql | 6 ++++ 3 files changed, 36 insertions(+), 5 deletions(-) create mode 100644 services/ml/migrations/002_add_user_id_to_training_runs.sql diff --git a/services/ml/app/db.py b/services/ml/app/db.py index 8fc8475..7f84233 100644 --- a/services/ml/app/db.py +++ b/services/ml/app/db.py @@ -46,6 +46,7 @@ class TrainingRun(Base): id = Column(Integer, primary_key=True, index=True) run_id = Column(String(255), unique=True, nullable=False, index=True) + user_id = Column(String(255), nullable=True, index=True) model_type = Column(String(100), nullable=False) experiment_name = Column(String(255), nullable=False, index=True) pipeline_config_hash = Column(String(64), nullable=False) diff --git a/services/ml/app/main.py b/services/ml/app/main.py index 5f6b40f..730c70c 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -1490,6 +1490,7 @@ async def training_start(request: TrainingStartRequest, user_id: Optional[str] = with get_db() as db: training_run = TrainingRun( run_id=run_id, + user_id=user_id, model_type=request.model_type, experiment_name=experiment_name, pipeline_config_hash=config_hash, @@ -1522,14 +1523,23 @@ async def training_start(request: TrainingStartRequest, user_id: 
Optional[str] = @app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)]) -async def training_runs(): +async def training_runs(user_id: Optional[str] = Depends(get_user_id)): """ Return training run history from the database, sorted by date descending. + + Filters results to only include runs belonging to the requesting user + (identified by the X-User-ID header). If no user ID is provided, returns + an empty list to prevent data leakage across users. """ try: from sqlalchemy import select with get_db() as db: stmt = select(TrainingRun).order_by(desc(TrainingRun.created_at)) + if user_id: + stmt = stmt.where(TrainingRun.user_id == user_id) + else: + # No user context — return nothing to prevent data leakage + return TrainingRunsResponse(runs=[]) rows = db.execute(stmt).scalars().all() runs = [ @@ -1577,13 +1587,13 @@ class DeleteRunResponse(BaseModel): @app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)]) -async def delete_training_run(run_id: str): +async def delete_training_run(run_id: str, user_id: Optional[str] = Depends(get_user_id)): """ Delete a training run record and its model artifact. Returns HTTP 400 if the run_id format is invalid. Returns HTTP 409 if the run is currently active. - Returns HTTP 404 if the run_id doesn't exist. + Returns HTTP 404 if the run_id doesn't exist or belongs to a different user. 
""" from sqlalchemy import select, delete as sa_delete @@ -1613,6 +1623,13 @@ async def delete_training_run(run_id: str): detail=f"Training run not found: {run_id}", ) + # Enforce user ownership: return 404 (not 403) to avoid leaking existence + if user_id and row.user_id and row.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Training run not found: {run_id}", + ) + db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id)) db.commit() except HTTPException: @@ -1644,12 +1661,12 @@ async def delete_training_run(run_id: str): @app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)]) -async def get_training_run(run_id: str): +async def get_training_run(run_id: str, user_id: Optional[str] = Depends(get_user_id)): """ Get information about a specific training run by run_id. Returns HTTP 400 if the run_id format is invalid. - Returns HTTP 404 if the run_id doesn't exist. + Returns HTTP 404 if the run_id doesn't exist or belongs to a different user. 
""" from sqlalchemy import select @@ -1671,6 +1688,13 @@ async def get_training_run(run_id: str): detail=f"Training run not found: {run_id}", ) + # Enforce user ownership: return 404 (not 403) to avoid leaking existence + if user_id and row.user_id and row.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Training run not found: {run_id}", + ) + return TrainingRunInfo( run_id=row.run_id, model_type=row.model_type, diff --git a/services/ml/migrations/002_add_user_id_to_training_runs.sql b/services/ml/migrations/002_add_user_id_to_training_runs.sql new file mode 100644 index 0000000..22c2c64 --- /dev/null +++ b/services/ml/migrations/002_add_user_id_to_training_runs.sql @@ -0,0 +1,6 @@ +-- Add user_id column to training_runs for per-user scoping of training run queries +ALTER TABLE training_runs + ADD COLUMN IF NOT EXISTS user_id VARCHAR(255); + +-- Create index on user_id for efficient per-user filtering +CREATE INDEX IF NOT EXISTS idx_training_runs_user_id ON training_runs(user_id);