Scope training run queries in FastAPI to filter by user ID (Task 14.3)
- Add user_id column to TrainingRun model in db.py
- Store user_id on TrainingRun insert in /training/start
- Filter GET /training/runs by user_id (returns empty list if no user context)
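- Enforce user ownership on GET /training/runs/{run_id} (404 on mismatch; the check applies only when both the request and the stored run carry a user_id)
- Enforce user ownership on DELETE /training/runs/{run_id} (404 on mismatch; the check applies only when both the request and the stored run carry a user_id)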
- Add migration 002 to add user_id column and index to training_runs table
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
cbb921b4a7
commit
688e75e6be
3 changed files with 36 additions and 5 deletions
|
|
@ -46,6 +46,7 @@ class TrainingRun(Base):
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
run_id = Column(String(255), unique=True, nullable=False, index=True)
|
run_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
user_id = Column(String(255), nullable=True, index=True)
|
||||||
model_type = Column(String(100), nullable=False)
|
model_type = Column(String(100), nullable=False)
|
||||||
experiment_name = Column(String(255), nullable=False, index=True)
|
experiment_name = Column(String(255), nullable=False, index=True)
|
||||||
pipeline_config_hash = Column(String(64), nullable=False)
|
pipeline_config_hash = Column(String(64), nullable=False)
|
||||||
|
|
|
||||||
|
|
@ -1490,6 +1490,7 @@ async def training_start(request: TrainingStartRequest, user_id: Optional[str] =
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
training_run = TrainingRun(
|
training_run = TrainingRun(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
user_id=user_id,
|
||||||
model_type=request.model_type,
|
model_type=request.model_type,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
pipeline_config_hash=config_hash,
|
pipeline_config_hash=config_hash,
|
||||||
|
|
@ -1522,14 +1523,23 @@ async def training_start(request: TrainingStartRequest, user_id: Optional[str] =
|
||||||
|
|
||||||
|
|
||||||
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
|
@app.get("/training/runs", response_model=TrainingRunsResponse, dependencies=[Depends(verify_api_key)])
|
||||||
async def training_runs():
|
async def training_runs(user_id: Optional[str] = Depends(get_user_id)):
|
||||||
"""
|
"""
|
||||||
Return training run history from the database, sorted by date descending.
|
Return training run history from the database, sorted by date descending.
|
||||||
|
|
||||||
|
Filters results to only include runs belonging to the requesting user
|
||||||
|
(identified by the X-User-ID header). If no user ID is provided, returns
|
||||||
|
an empty list to prevent data leakage across users.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
stmt = select(TrainingRun).order_by(desc(TrainingRun.created_at))
|
stmt = select(TrainingRun).order_by(desc(TrainingRun.created_at))
|
||||||
|
if user_id:
|
||||||
|
stmt = stmt.where(TrainingRun.user_id == user_id)
|
||||||
|
else:
|
||||||
|
# No user context — return nothing to prevent data leakage
|
||||||
|
return TrainingRunsResponse(runs=[])
|
||||||
rows = db.execute(stmt).scalars().all()
|
rows = db.execute(stmt).scalars().all()
|
||||||
|
|
||||||
runs = [
|
runs = [
|
||||||
|
|
@ -1577,13 +1587,13 @@ class DeleteRunResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
|
@app.delete("/training/runs/{run_id}", response_model=DeleteRunResponse, dependencies=[Depends(verify_api_key)])
|
||||||
async def delete_training_run(run_id: str):
|
async def delete_training_run(run_id: str, user_id: Optional[str] = Depends(get_user_id)):
|
||||||
"""
|
"""
|
||||||
Delete a training run record and its model artifact.
|
Delete a training run record and its model artifact.
|
||||||
|
|
||||||
Returns HTTP 400 if the run_id format is invalid.
|
Returns HTTP 400 if the run_id format is invalid.
|
||||||
Returns HTTP 409 if the run is currently active.
|
Returns HTTP 409 if the run is currently active.
|
||||||
Returns HTTP 404 if the run_id doesn't exist.
|
Returns HTTP 404 if the run_id doesn't exist or belongs to a different user.
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, delete as sa_delete
|
from sqlalchemy import select, delete as sa_delete
|
||||||
|
|
||||||
|
|
@ -1613,6 +1623,13 @@ async def delete_training_run(run_id: str):
|
||||||
detail=f"Training run not found: {run_id}",
|
detail=f"Training run not found: {run_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
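# Enforce user ownership: return 404 (not 403) to avoid leaking existence.
# NOTE(review): the check below is skipped when the request has no user ID or
# the row has no user_id, so such runs can still be deleted — confirm this is intended.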
|
||||||
|
if user_id and row.user_id and row.user_id != user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Training run not found: {run_id}",
|
||||||
|
)
|
||||||
|
|
||||||
db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
|
db.execute(sa_delete(TrainingRun).where(TrainingRun.run_id == run_id))
|
||||||
db.commit()
|
db.commit()
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
@ -1644,12 +1661,12 @@ async def delete_training_run(run_id: str):
|
||||||
|
|
||||||
|
|
||||||
@app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)])
|
@app.get("/training/runs/{run_id}", response_model=TrainingRunInfo, dependencies=[Depends(verify_api_key)])
|
||||||
async def get_training_run(run_id: str):
|
async def get_training_run(run_id: str, user_id: Optional[str] = Depends(get_user_id)):
|
||||||
"""
|
"""
|
||||||
Get information about a specific training run by run_id.
|
Get information about a specific training run by run_id.
|
||||||
|
|
||||||
Returns HTTP 400 if the run_id format is invalid.
|
Returns HTTP 400 if the run_id format is invalid.
|
||||||
Returns HTTP 404 if the run_id doesn't exist.
|
Returns HTTP 404 if the run_id doesn't exist or belongs to a different user.
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -1671,6 +1688,13 @@ async def get_training_run(run_id: str):
|
||||||
detail=f"Training run not found: {run_id}",
|
detail=f"Training run not found: {run_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
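# Enforce user ownership: return 404 (not 403) to avoid leaking existence.
# NOTE(review): the check below is skipped when the request has no user ID or
# the row has no user_id, so such runs can still be read — confirm this is intended.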
|
||||||
|
if user_id and row.user_id and row.user_id != user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Training run not found: {run_id}",
|
||||||
|
)
|
||||||
|
|
||||||
return TrainingRunInfo(
|
return TrainingRunInfo(
|
||||||
run_id=row.run_id,
|
run_id=row.run_id,
|
||||||
model_type=row.model_type,
|
model_type=row.model_type,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
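-- Add a nullable user_id column to training_runs for per-user scoping of training run queries.
-- No default is set, so rows created before this migration keep user_id = NULL (no owner).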
|
||||||
|
ALTER TABLE training_runs
|
||||||
|
ADD COLUMN IF NOT EXISTS user_id VARCHAR(255);
|
||||||
|
|
||||||
|
-- Create index on user_id for efficient per-user filtering
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_training_runs_user_id ON training_runs(user_id);
|
||||||
Loading…
Add table
Add a link
Reference in a new issue