Scope MLflow experiment names to include user ID (Task 14.2)

- Updated FastAPI /training/start endpoint to extract X-User-ID header via get_user_id() dependency
- Modified _run_training_background to accept and use user_id parameter
- Added MLflow experiment setup with user scoping: experiments are named user_{user_id}_training when user_id is provided, falling back to default experiment name otherwise
- Updated database record insertion to store scoped experiment name
- Updated training/train.py train() function to accept user_id parameter and use it for experiment naming
- Mark task 14.2 as complete in tasks.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-20 13:46:00 +01:00
parent 9f76d7eb62
commit cbb921b4a7
3 changed files with 55 additions and 16 deletions

View file

@ -172,33 +172,40 @@ def compute_config_hash(config: PipelineConfig) -> str:
def train(
config: PipelineConfig,
labeled_data_path: Path,
output_model_path: Optional[Path] = None
output_model_path: Optional[Path] = None,
user_id: Optional[str] = None
) -> str:
"""
Main training function.
Loads labeled data, splits, trains model, evaluates, logs to MLflow,
and stores metadata in PostgreSQL.
Args:
config: Pipeline configuration
labeled_data_path: Path to labeled CSV file
output_model_path: Optional path to save model locally (for inference)
user_id: Optional user ID for scoped experiment naming (e.g., user_{uuid}_training)
Returns:
MLflow run ID
"""
training_config = config.stages.training
mlflow_config = training_config.mlflow
# Initialize database
init_db()
# Set MLflow tracking URI
mlflow.set_tracking_uri(mlflow_config.tracking_uri)
# Set experiment
mlflow.set_experiment(mlflow_config.experiment_name)
# Set experiment with user scoping if user_id is provided
if user_id:
experiment_name = f"user_{user_id}_training"
else:
experiment_name = mlflow_config.experiment_name
mlflow.set_experiment(experiment_name)
logger.info(f"Loading labeled data from {labeled_data_path}")
df = pd.read_csv(labeled_data_path)
@ -247,7 +254,7 @@ def train(
training_run = TrainingRun(
run_id=run_id,
model_type=training_config.model_type,
experiment_name=mlflow_config.experiment_name,
experiment_name=experiment_name,
pipeline_config_hash=compute_config_hash(config),
dataset_version=None, # TODO: Add DVC hash if available
metrics_summary={},