Scope MLflow experiment names to include user ID (Task 14.2)
- Updated FastAPI /training/start endpoint to extract X-User-ID header via get_user_id() dependency
- Modified _run_training_background to accept and use user_id parameter
- Added MLflow experiment setup with user scoping: experiments are named user_{user_id}_training when user_id is provided; otherwise the configured default experiment name (mlflow_config.experiment_name) is used
- Updated database record insertion to store scoped experiment name
- Updated training/train.py train() function to accept user_id parameter and use it for experiment naming
- Marked task 14.2 as complete in tasks.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
9f76d7eb62
commit
cbb921b4a7
3 changed files with 55 additions and 16 deletions
|
|
@ -82,7 +82,7 @@
|
||||||
## 14. ML Service User Scoping
|
## 14. ML Service User Scoping
|
||||||
|
|
||||||
- [x] 14.1 `[haiku]` Update FastAPI service to read `X-User-ID` header from incoming requests
|
- [x] 14.1 `[haiku]` Update FastAPI service to read `X-User-ID` header from incoming requests
|
||||||
- [ ] 14.2 `[haiku]` Scope MLflow experiment names to include user ID (e.g., `user_{uuid}_training`)
|
- [x] 14.2 `[haiku]` Scope MLflow experiment names to include user ID (e.g., `user_{uuid}_training`)
|
||||||
- [ ] 14.3 `[sonnet]` Scope training run queries in FastAPI to filter by user ID
|
- [ ] 14.3 `[sonnet]` Scope training run queries in FastAPI to filter by user ID
|
||||||
|
|
||||||
## 15. Documentation & Deployment
|
## 15. Documentation & Deployment
|
||||||
|
|
|
||||||
|
|
@ -1230,19 +1230,40 @@ def _run_training_background(
|
||||||
model_type: str,
|
model_type: str,
|
||||||
config: PipelineConfig,
|
config: PipelineConfig,
|
||||||
chart_id: Optional[int] = None,
|
chart_id: Optional[int] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Background thread target: build dataset then train a model.
|
Background thread target: build dataset then train a model.
|
||||||
|
|
||||||
Uses the pre-inserted TrainingRun record identified by ``run_id``.
|
Uses the pre-inserted TrainingRun record identified by ``run_id``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Training run ID
|
||||||
|
model_type: Type of model to train
|
||||||
|
config: Pipeline configuration
|
||||||
|
chart_id: Optional chart ID to train on
|
||||||
|
user_id: Optional user ID for scoped experiment naming
|
||||||
"""
|
"""
|
||||||
logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}")
|
logger.info(f"Training thread started: run_id={run_id}, model_type={model_type}, user_id={user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Import training utilities here to avoid circular import issues
|
# Import training utilities here to avoid circular import issues
|
||||||
from training.train import create_model, temporal_split
|
from training.train import create_model, temporal_split
|
||||||
from sklearn.metrics import accuracy_score, f1_score
|
from sklearn.metrics import accuracy_score, f1_score
|
||||||
|
|
||||||
|
# Set up MLflow experiment with user scoping
|
||||||
|
mlflow_config = config.stages.training.mlflow
|
||||||
|
mlflow.set_tracking_uri(mlflow_config.tracking_uri)
|
||||||
|
|
||||||
|
# Use user-scoped experiment name if user_id provided, otherwise use default
|
||||||
|
if user_id:
|
||||||
|
experiment_name = f"user_{user_id}_training"
|
||||||
|
else:
|
||||||
|
experiment_name = mlflow_config.experiment_name
|
||||||
|
|
||||||
|
mlflow.set_experiment(experiment_name)
|
||||||
|
logger.info(f"MLflow experiment set to: {experiment_name}")
|
||||||
|
|
||||||
# Build dataset from database (feature engineering + annotation ingestion)
|
# Build dataset from database (feature engineering + annotation ingestion)
|
||||||
logger.info("Building dataset from database...")
|
logger.info("Building dataset from database...")
|
||||||
build_dataset_from_db(config, chart_id=chart_id)
|
build_dataset_from_db(config, chart_id=chart_id)
|
||||||
|
|
@ -1407,12 +1428,16 @@ def _run_training_background(
|
||||||
|
|
||||||
|
|
||||||
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
|
@app.post("/training/start", response_model=TrainingStartResponse, dependencies=[Depends(verify_api_key)])
|
||||||
async def training_start(request: TrainingStartRequest):
|
async def training_start(request: TrainingStartRequest, user_id: Optional[str] = Depends(get_user_id)):
|
||||||
"""
|
"""
|
||||||
Start a training run in a background thread.
|
Start a training run in a background thread.
|
||||||
|
|
||||||
Returns immediately with run_id and status "running".
|
Returns immediately with run_id and status "running".
|
||||||
Rejects concurrent runs with HTTP 409.
|
Rejects concurrent runs with HTTP 409.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Training request parameters
|
||||||
|
user_id: Optional user ID from X-User-ID header for scoped experiments
|
||||||
"""
|
"""
|
||||||
# Validate model type
|
# Validate model type
|
||||||
if request.model_type not in SUPPORTED_MODEL_TYPES:
|
if request.model_type not in SUPPORTED_MODEL_TYPES:
|
||||||
|
|
@ -1455,11 +1480,18 @@ async def training_start(request: TrainingStartRequest):
|
||||||
|
|
||||||
# Pre-insert the run record so callers can track it immediately
|
# Pre-insert the run record so callers can track it immediately
|
||||||
try:
|
try:
|
||||||
|
# Compute scoped experiment name
|
||||||
|
mlflow_config = config.stages.training.mlflow
|
||||||
|
if user_id:
|
||||||
|
experiment_name = f"user_{user_id}_training"
|
||||||
|
else:
|
||||||
|
experiment_name = mlflow_config.experiment_name
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
training_run = TrainingRun(
|
training_run = TrainingRun(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
model_type=request.model_type,
|
model_type=request.model_type,
|
||||||
experiment_name=config.stages.training.mlflow.experiment_name,
|
experiment_name=experiment_name,
|
||||||
pipeline_config_hash=config_hash,
|
pipeline_config_hash=config_hash,
|
||||||
status="running",
|
status="running",
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(timezone.utc),
|
||||||
|
|
@ -1479,13 +1511,13 @@ async def training_start(request: TrainingStartRequest):
|
||||||
# Launch background thread (daemon so it doesn't block process exit)
|
# Launch background thread (daemon so it doesn't block process exit)
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=_run_training_background,
|
target=_run_training_background,
|
||||||
args=(run_id, request.model_type, config, request.chart_id),
|
args=(run_id, request.model_type, config, request.chart_id, user_id),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name=f"training-{run_id[:8]}",
|
name=f"training-{run_id[:8]}",
|
||||||
)
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}")
|
logger.info(f"Training started: run_id={run_id}, model_type={request.model_type}, user_id={user_id or 'default'}")
|
||||||
return TrainingStartResponse(run_id=run_id, status="running")
|
return TrainingStartResponse(run_id=run_id, status="running")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -172,33 +172,40 @@ def compute_config_hash(config: PipelineConfig) -> str:
|
||||||
def train(
|
def train(
|
||||||
config: PipelineConfig,
|
config: PipelineConfig,
|
||||||
labeled_data_path: Path,
|
labeled_data_path: Path,
|
||||||
output_model_path: Optional[Path] = None
|
output_model_path: Optional[Path] = None,
|
||||||
|
user_id: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Main training function.
|
Main training function.
|
||||||
|
|
||||||
Loads labeled data, splits, trains model, evaluates, logs to MLflow,
|
Loads labeled data, splits, trains model, evaluates, logs to MLflow,
|
||||||
and stores metadata in PostgreSQL.
|
and stores metadata in PostgreSQL.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Pipeline configuration
|
config: Pipeline configuration
|
||||||
labeled_data_path: Path to labeled CSV file
|
labeled_data_path: Path to labeled CSV file
|
||||||
output_model_path: Optional path to save model locally (for inference)
|
output_model_path: Optional path to save model locally (for inference)
|
||||||
|
user_id: Optional user ID for scoped experiment naming (e.g., user_{uuid}_training)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MLflow run ID
|
MLflow run ID
|
||||||
"""
|
"""
|
||||||
training_config = config.stages.training
|
training_config = config.stages.training
|
||||||
mlflow_config = training_config.mlflow
|
mlflow_config = training_config.mlflow
|
||||||
|
|
||||||
# Initialize database
|
# Initialize database
|
||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
# Set MLflow tracking URI
|
# Set MLflow tracking URI
|
||||||
mlflow.set_tracking_uri(mlflow_config.tracking_uri)
|
mlflow.set_tracking_uri(mlflow_config.tracking_uri)
|
||||||
|
|
||||||
# Set experiment
|
# Set experiment with user scoping if user_id is provided
|
||||||
mlflow.set_experiment(mlflow_config.experiment_name)
|
if user_id:
|
||||||
|
experiment_name = f"user_{user_id}_training"
|
||||||
|
else:
|
||||||
|
experiment_name = mlflow_config.experiment_name
|
||||||
|
|
||||||
|
mlflow.set_experiment(experiment_name)
|
||||||
|
|
||||||
logger.info(f"Loading labeled data from {labeled_data_path}")
|
logger.info(f"Loading labeled data from {labeled_data_path}")
|
||||||
df = pd.read_csv(labeled_data_path)
|
df = pd.read_csv(labeled_data_path)
|
||||||
|
|
@ -247,7 +254,7 @@ def train(
|
||||||
training_run = TrainingRun(
|
training_run = TrainingRun(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
model_type=training_config.model_type,
|
model_type=training_config.model_type,
|
||||||
experiment_name=mlflow_config.experiment_name,
|
experiment_name=experiment_name,
|
||||||
pipeline_config_hash=compute_config_hash(config),
|
pipeline_config_hash=compute_config_hash(config),
|
||||||
dataset_version=None, # TODO: Add DVC hash if available
|
dataset_version=None, # TODO: Add DVC hash if available
|
||||||
metrics_summary={},
|
metrics_summary={},
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue