Scope MLflow experiment names to include user ID (Task 14.2)

- Updated FastAPI /training/start endpoint to extract X-User-ID header via get_user_id() dependency
- Modified _run_training_background to accept and use user_id parameter
- Added MLflow experiment setup with user scoping: experiments are named user_{user_id}_training when user_id is provided, falling back to default experiment name otherwise
- Updated database record insertion to store scoped experiment name
- Updated training/train.py train() function to accept user_id parameter and use it for experiment naming
- Mark task 14.2 as complete in tasks.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-20 13:46:00 +01:00
parent 9f76d7eb62
commit cbb921b4a7
3 changed files with 55 additions and 16 deletions

View file

@ -82,7 +82,7 @@
## 14. ML Service User Scoping ## 14. ML Service User Scoping
- [x] 14.1 `[haiku]` Update FastAPI service to read `X-User-ID` header from incoming requests - [x] 14.1 `[haiku]` Update FastAPI service to read `X-User-ID` header from incoming requests
- [ ] 14.2 `[haiku]` Scope MLflow experiment names to include user ID (e.g., `user_{uuid}_training`) - [x] 14.2 `[haiku]` Scope MLflow experiment names to include user ID (e.g., `user_{uuid}_training`)
- [ ] 14.3 `[sonnet]` Scope training run queries in FastAPI to filter by user ID - [ ] 14.3 `[sonnet]` Scope training run queries in FastAPI to filter by user ID
## 15. Documentation & Deployment ## 15. Documentation & Deployment

View file

@ -1230,19 +1230,40 @@ def _run_training_background(
model_type: str, model_type: str,
config: PipelineConfig, config: PipelineConfig,
chart_id: Optional[int] = None, chart_id: Optional[int] = None,
user_id: Optional[str] = None,
) -> None: ) -> None:
""" """
Background thread target: build dataset then train a model. Background thread target: build dataset then train a model.
Uses the pre-inserted TrainingRun record identified by ``run_id``. Uses the pre-inserted TrainingRun record identified by ``run_id``.
Args:
run_id: Training run ID
model_type: Type of model to train
config: Pipeline configuration
chart_id: Optional chart ID to train on
user_id: Optional user ID for scoped experiment naming
""" """
logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}") logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}, user_id={user_id}")
try: try:
# Import training utilities here to avoid circular import issues # Import training utilities here to avoid circular import issues
from training.train import create_model, temporal_split from training.train import create_model, temporal_split
from sklearn.metrics import accuracy_score, f1_score from sklearn.metrics import accuracy_score, f1_score
# Set up MLflow experiment with user scoping
mlflow_config = config.stages.training.mlflow
mlflow.set_tracking_uri(mlflow_config.tracking_uri)
# Use user-scoped experiment name if user_id provided, otherwise use default
if user_id:
experiment_name = f"user_{user_id}_training"
else:
experiment_name = mlflow_config.experiment_name
mlflow.set_experiment(experiment_name)
logger.info(f"MLflow experiment set to: {experiment_name}")
# Build dataset from database (feature engineering + annotation ingestion) # Build dataset from database (feature engineering + annotation ingestion)
logger.info("Building dataset from database...") logger.info("Building dataset from database...")
build_dataset_from_db(config, chart_id=chart_id) build_dataset_from_db(config, chart_id=chart_id)
@ -1407,12 +1428,16 @@ def _run_training_background(
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)]) @app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
async def training_start(request: TrainingStartRequest): async def training_start(request: TrainingStartRequest, user_id: Optional[str] = Depends(get_user_id)):
""" """
Start a training run in a background thread. Start a training run in a background thread.
Returns immediately with run_id and status "running". Returns immediately with run_id and status "running".
Rejects concurrent runs with HTTP 409. Rejects concurrent runs with HTTP 409.
Args:
request: Training request parameters
user_id: Optional user ID from X-User-ID header for scoped experiments
""" """
# Validate model type # Validate model type
if request.model_type not in SUPPORTED_MODEL_TYPES: if request.model_type not in SUPPORTED_MODEL_TYPES:
@ -1455,11 +1480,18 @@ async def training_start(request: TrainingStartRequest):
# Pre-insert the run record so callers can track it immediately # Pre-insert the run record so callers can track it immediately
try: try:
# Compute scoped experiment name
mlflow_config = config.stages.training.mlflow
if user_id:
experiment_name = f"user_{user_id}_training"
else:
experiment_name = mlflow_config.experiment_name
with get_db() as db: with get_db() as db:
training_run = TrainingRun( training_run = TrainingRun(
run_id=run_id, run_id=run_id,
model_type=request.model_type, model_type=request.model_type,
experiment_name=config.stages.training.mlflow.experiment_name, experiment_name=experiment_name,
pipeline_config_hash=config_hash, pipeline_config_hash=config_hash,
status="running", status="running",
created_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc),
@ -1479,13 +1511,13 @@ async def training_start(request: TrainingStartRequest):
# Launch background thread (daemon so it doesn't block process exit) # Launch background thread (daemon so it doesn't block process exit)
thread = threading.Thread( thread = threading.Thread(
target=_run_training_background, target=_run_training_background,
args=(run_id, request.model_type, config, request.chart_id), args=(run_id, request.model_type, config, request.chart_id, user_id),
daemon=True, daemon=True,
name=f"training-{run_id[:8]}", name=f"training-{run_id[:8]}",
) )
thread.start() thread.start()
logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}") logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}, user_id={user_id or 'default'}")
return TrainingStartResponse(run_id=run_id, status="running") return TrainingStartResponse(run_id=run_id, status="running")

View file

@ -172,7 +172,8 @@ def compute_config_hash(config: PipelineConfig) -> str:
def train( def train(
config: PipelineConfig, config: PipelineConfig,
labeled_data_path: Path, labeled_data_path: Path,
output_model_path: Optional[Path] = None output_model_path: Optional[Path] = None,
user_id: Optional[str] = None
) -> str: ) -> str:
""" """
Main training function. Main training function.
@ -184,6 +185,7 @@ def train(
config: Pipeline configuration config: Pipeline configuration
labeled_data_path: Path to labeled CSV file labeled_data_path: Path to labeled CSV file
output_model_path: Optional path to save model locally (for inference) output_model_path: Optional path to save model locally (for inference)
user_id: Optional user ID for scoped experiment naming (e.g., user_{uuid}_training)
Returns: Returns:
MLflow run ID MLflow run ID
@ -197,8 +199,13 @@ def train(
# Set MLflow tracking URI # Set MLflow tracking URI
mlflow.set_tracking_uri(mlflow_config.tracking_uri) mlflow.set_tracking_uri(mlflow_config.tracking_uri)
# Set experiment # Set experiment with user scoping if user_id is provided
mlflow.set_experiment(mlflow_config.experiment_name) if user_id:
experiment_name = f"user_{user_id}_training"
else:
experiment_name = mlflow_config.experiment_name
mlflow.set_experiment(experiment_name)
logger.info(f"Loading labeled data from {labeled_data_path}") logger.info(f"Loading labeled data from {labeled_data_path}")
df = pd.read_csv(labeled_data_path) df = pd.read_csv(labeled_data_path)
@ -247,7 +254,7 @@ def train(
training_run = TrainingRun( training_run = TrainingRun(
run_id=run_id, run_id=run_id,
model_type=training_config.model_type, model_type=training_config.model_type,
experiment_name=mlflow_config.experiment_name, experiment_name=experiment_name,
pipeline_config_hash=compute_config_hash(config), pipeline_config_hash=compute_config_hash(config),
dataset_version=None, # TODO: Add DVC hash if available dataset_version=None, # TODO: Add DVC hash if available
metrics_summary={}, metrics_summary={},