feat(ml): implement training stage with MLflow tracking and model wrappers
- Create RandomForestModel and XGBoostModel wrappers with class weight support - Implement temporal and random train/val/test splitting - Add MLflow experiment tracking with full parameter and metric logging - Create evaluation module for confusion matrix, feature importance, and classification reports - Implement model training with sklearn/xgboost flavor logging and optional registry registration - Store training run metadata in PostgreSQL - Wire training stage into pipeline.py orchestrator - Support both RandomForest and XGBoost models with configurable hyperparameters
This commit is contained in:
parent
16763b967e
commit
f4c0f9a836
8 changed files with 900 additions and 14 deletions
445
services/ml/training/train.py
Normal file
445
services/ml/training/train.py
Normal file
|
|
@ -0,0 +1,445 @@
|
|||
"""
|
||||
Model training module.
|
||||
|
||||
Main training entry point: load labeled CSV, split, train, evaluate, log to MLflow.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import mlflow
|
||||
import mlflow.sklearn
|
||||
import mlflow.xgboost
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
||||
|
||||
from app.config import PipelineConfig, TrainingConfig
|
||||
from app.db import get_db, TrainingRun, init_db
|
||||
from training.models.random_forest import RandomForestModel
|
||||
from training.models.xgboost_model import XGBoostModel
|
||||
from training import evaluation
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def temporal_split(
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
test_split: float,
|
||||
validation_split: float
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Split data temporally into train/validation/test sets.
|
||||
|
||||
Data is assumed to already be sorted by time.
|
||||
Split ratios:
|
||||
- test_split: fraction for test set (from the end)
|
||||
- validation_split: fraction for validation set (from remaining data after test)
|
||||
- remainder: training set
|
||||
|
||||
Args:
|
||||
X: Feature matrix (n_samples, n_features)
|
||||
y: Labels (n_samples,)
|
||||
test_split: Test set fraction (0.0-1.0)
|
||||
validation_split: Validation set fraction (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
X_train, X_val, X_test, y_train, y_val, y_test
|
||||
"""
|
||||
n_samples = len(X)
|
||||
|
||||
# Calculate split indices
|
||||
n_test = int(n_samples * test_split)
|
||||
n_val = int((n_samples - n_test) * validation_split)
|
||||
n_train = n_samples - n_test - n_val
|
||||
|
||||
# Split
|
||||
X_train = X[:n_train]
|
||||
y_train = y[:n_train]
|
||||
|
||||
X_val = X[n_train:n_train + n_val]
|
||||
y_val = y[n_train:n_train + n_val]
|
||||
|
||||
X_test = X[n_train + n_val:]
|
||||
y_test = y[n_train + n_val:]
|
||||
|
||||
logger.info(f"Temporal split: train={n_train}, val={n_val}, test={n_test}")
|
||||
|
||||
return X_train, X_val, X_test, y_train, y_val, y_test
|
||||
|
||||
|
||||
def random_split(
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
test_split: float,
|
||||
validation_split: float,
|
||||
random_state: int = 42
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Split data randomly into train/validation/test sets.
|
||||
|
||||
WARNING: Not recommended for financial time series data.
|
||||
|
||||
Args:
|
||||
X: Feature matrix (n_samples, n_features)
|
||||
y: Labels (n_samples,)
|
||||
test_split: Test set fraction (0.0-1.0)
|
||||
validation_split: Validation set fraction (0.0-1.0)
|
||||
random_state: Random seed
|
||||
|
||||
Returns:
|
||||
X_train, X_val, X_test, y_train, y_val, y_test
|
||||
"""
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
warnings.warn(
|
||||
"Random splitting is not recommended for financial time series data. "
|
||||
"Use temporal splitting instead to avoid data leakage.",
|
||||
UserWarning
|
||||
)
|
||||
logger.warning("Using random split (not recommended for time series)")
|
||||
|
||||
# First split: train+val vs test
|
||||
X_temp, X_test, y_temp, y_test = train_test_split(
|
||||
X, y,
|
||||
test_size=test_split,
|
||||
random_state=random_state,
|
||||
stratify=y
|
||||
)
|
||||
|
||||
# Second split: train vs val
|
||||
val_size = validation_split / (1 - test_split)
|
||||
X_train, X_val, y_train, y_val = train_test_split(
|
||||
X_temp, y_temp,
|
||||
test_size=val_size,
|
||||
random_state=random_state,
|
||||
stratify=y_temp
|
||||
)
|
||||
|
||||
logger.info(f"Random split: train={len(X_train)}, val={len(X_val)}, test={len(X_test)}")
|
||||
|
||||
return X_train, X_val, X_test, y_train, y_val, y_test
|
||||
|
||||
|
||||
def create_model(model_type: str, hyperparameters: dict, class_weights: Optional[str]):
|
||||
"""
|
||||
Create model instance based on model_type.
|
||||
|
||||
Args:
|
||||
model_type: "random_forest" or "xgboost"
|
||||
hyperparameters: Model hyperparameters
|
||||
class_weights: "balanced" or None
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
|
||||
Raises:
|
||||
ValueError: If model_type is not supported
|
||||
"""
|
||||
if model_type == "random_forest":
|
||||
return RandomForestModel(hyperparameters, class_weights)
|
||||
elif model_type == "xgboost":
|
||||
return XGBoostModel(hyperparameters, class_weights)
|
||||
else:
|
||||
supported_types = ["random_forest", "xgboost"]
|
||||
raise ValueError(
|
||||
f"Unsupported model type: {model_type}. "
|
||||
f"Supported types: {supported_types}"
|
||||
)
|
||||
|
||||
|
||||
def compute_config_hash(config: PipelineConfig) -> str:
|
||||
"""
|
||||
Compute hash of pipeline configuration.
|
||||
|
||||
Args:
|
||||
config: Pipeline configuration
|
||||
|
||||
Returns:
|
||||
SHA256 hash (first 16 chars)
|
||||
"""
|
||||
import json
|
||||
config_str = json.dumps(config.model_dump(), sort_keys=True)
|
||||
return hashlib.sha256(config_str.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
def train(
|
||||
config: PipelineConfig,
|
||||
labeled_data_path: Path,
|
||||
output_model_path: Optional[Path] = None
|
||||
) -> str:
|
||||
"""
|
||||
Main training function.
|
||||
|
||||
Loads labeled data, splits, trains model, evaluates, logs to MLflow,
|
||||
and stores metadata in PostgreSQL.
|
||||
|
||||
Args:
|
||||
config: Pipeline configuration
|
||||
labeled_data_path: Path to labeled CSV file
|
||||
output_model_path: Optional path to save model locally (for inference)
|
||||
|
||||
Returns:
|
||||
MLflow run ID
|
||||
"""
|
||||
training_config = config.stages.training
|
||||
mlflow_config = training_config.mlflow
|
||||
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
# Set MLflow tracking URI
|
||||
mlflow.set_tracking_uri(mlflow_config.tracking_uri)
|
||||
|
||||
# Set experiment
|
||||
mlflow.set_experiment(mlflow_config.experiment_name)
|
||||
|
||||
logger.info(f"Loading labeled data from {labeled_data_path}")
|
||||
df = pd.read_csv(labeled_data_path)
|
||||
|
||||
# Separate features and labels
|
||||
if 'label' not in df.columns:
|
||||
raise ValueError("Labeled dataset must have 'label' column")
|
||||
|
||||
label_col = 'label'
|
||||
feature_cols = [col for col in df.columns if col not in ['label', 'time', 'timestamp']]
|
||||
|
||||
X = df[feature_cols].values
|
||||
y = df[label_col].values
|
||||
feature_names = feature_cols
|
||||
|
||||
logger.info(f"Loaded {len(X)} samples with {len(feature_cols)} features")
|
||||
logger.info(f"Class distribution: {dict(zip(*np.unique(y, return_counts=True)))}")
|
||||
|
||||
# Split data
|
||||
if training_config.split_method == "temporal":
|
||||
X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
|
||||
X, y,
|
||||
training_config.test_split,
|
||||
training_config.validation_split
|
||||
)
|
||||
elif training_config.split_method == "random":
|
||||
X_train, X_val, X_test, y_train, y_val, y_test = random_split(
|
||||
X, y,
|
||||
training_config.test_split,
|
||||
training_config.validation_split,
|
||||
random_state=training_config.hyperparameters.get('random_state', 42)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown split_method: {training_config.split_method}")
|
||||
|
||||
# Start MLflow run
|
||||
with mlflow.start_run() as run:
|
||||
run_id = run.info.run_id
|
||||
logger.info(f"Started MLflow run: {run_id}")
|
||||
|
||||
# Create training run record in PostgreSQL
|
||||
with get_db() as db:
|
||||
training_run = TrainingRun(
|
||||
run_id=run_id,
|
||||
model_type=training_config.model_type,
|
||||
experiment_name=mlflow_config.experiment_name,
|
||||
pipeline_config_hash=compute_config_hash(config),
|
||||
dataset_version=None, # TODO: Add DVC hash if available
|
||||
metrics_summary={},
|
||||
status="running",
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(training_run)
|
||||
db.commit()
|
||||
|
||||
# Log parameters
|
||||
mlflow.log_param("model_type", training_config.model_type)
|
||||
mlflow.log_param("split_method", training_config.split_method)
|
||||
mlflow.log_param("test_split", training_config.test_split)
|
||||
mlflow.log_param("validation_split", training_config.validation_split)
|
||||
mlflow.log_param("class_weights", training_config.class_weights)
|
||||
mlflow.log_param("n_train_samples", len(X_train))
|
||||
mlflow.log_param("n_val_samples", len(X_val))
|
||||
mlflow.log_param("n_test_samples", len(X_test))
|
||||
mlflow.log_param("n_features", X.shape[1])
|
||||
mlflow.log_param("n_classes", len(np.unique(y)))
|
||||
|
||||
# Log per-class sample counts
|
||||
for label, count in zip(*np.unique(y_train, return_counts=True)):
|
||||
mlflow.log_param(f"train_samples_{label}", int(count))
|
||||
|
||||
# Log all hyperparameters
|
||||
for param, value in training_config.hyperparameters.items():
|
||||
mlflow.log_param(param, value)
|
||||
|
||||
# Log pipeline config as artifact
|
||||
import yaml
|
||||
config_dict = config.model_dump()
|
||||
config_yaml = yaml.dump(config_dict, default_flow_style=False)
|
||||
mlflow.log_text(config_yaml, "pipeline_config.yaml")
|
||||
|
||||
# Create and train model
|
||||
logger.info(f"Training {training_config.model_type} model")
|
||||
model = create_model(
|
||||
training_config.model_type,
|
||||
training_config.hyperparameters,
|
||||
training_config.class_weights
|
||||
)
|
||||
|
||||
model.fit(X_train, y_train)
|
||||
logger.info("Training complete")
|
||||
|
||||
# Evaluate on validation set
|
||||
y_val_pred = model.predict(X_val)
|
||||
val_accuracy = accuracy_score(y_val, y_val_pred)
|
||||
val_f1_macro = f1_score(y_val, y_val_pred, average='macro')
|
||||
val_f1_weighted = f1_score(y_val, y_val_pred, average='weighted')
|
||||
|
||||
mlflow.log_metric("val_accuracy", val_accuracy)
|
||||
mlflow.log_metric("val_f1_macro", val_f1_macro)
|
||||
mlflow.log_metric("val_f1_weighted", val_f1_weighted)
|
||||
|
||||
# Evaluate on test set
|
||||
y_test_pred = model.predict(X_test)
|
||||
test_accuracy = accuracy_score(y_test, y_test_pred)
|
||||
test_f1_macro = f1_score(y_test, y_test_pred, average='macro')
|
||||
test_f1_weighted = f1_score(y_test, y_test_pred, average='weighted')
|
||||
|
||||
mlflow.log_metric("test_accuracy", test_accuracy)
|
||||
mlflow.log_metric("test_f1_macro", test_f1_macro)
|
||||
mlflow.log_metric("test_f1_weighted", test_f1_weighted)
|
||||
|
||||
logger.info(f"Test accuracy: {test_accuracy:.4f}")
|
||||
logger.info(f"Test F1 (macro): {test_f1_macro:.4f}")
|
||||
logger.info(f"Test F1 (weighted): {test_f1_weighted:.4f}")
|
||||
|
||||
# Compute per-class metrics
|
||||
classes = model.classes_
|
||||
precision, recall, f1, support = precision_recall_fscore_support(
|
||||
y_test, y_test_pred, labels=classes, average=None
|
||||
)
|
||||
|
||||
for i, label in enumerate(classes):
|
||||
mlflow.log_metric(f"test_precision_{label}", precision[i])
|
||||
mlflow.log_metric(f"test_recall_{label}", recall[i])
|
||||
mlflow.log_metric(f"test_f1_{label}", f1[i])
|
||||
logger.info(f"Class {label}: P={precision[i]:.4f}, R={recall[i]:.4f}, F1={f1[i]:.4f}")
|
||||
|
||||
# Generate and log artifacts if enabled
|
||||
if mlflow_config.log_artifacts:
|
||||
logger.info("Generating evaluation artifacts")
|
||||
|
||||
# Confusion matrix
|
||||
cm_bytes = evaluation.generate_confusion_matrix_plot(
|
||||
y_test, y_test_pred, labels=classes.tolist()
|
||||
)
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
|
||||
f.write(cm_bytes)
|
||||
cm_path = f.name
|
||||
mlflow.log_artifact(cm_path, "confusion_matrix.png")
|
||||
Path(cm_path).unlink()
|
||||
|
||||
# Feature importance (if available)
|
||||
if hasattr(model, 'feature_importances_'):
|
||||
fi_bytes = evaluation.generate_feature_importance_plot(
|
||||
feature_names, model.feature_importances_
|
||||
)
|
||||
with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
|
||||
f.write(fi_bytes)
|
||||
fi_path = f.name
|
||||
mlflow.log_artifact(fi_path, "feature_importance.png")
|
||||
Path(fi_path).unlink()
|
||||
|
||||
# Classification report
|
||||
report_text = evaluation.generate_classification_report_text(
|
||||
y_test, y_test_pred, labels=classes.tolist()
|
||||
)
|
||||
mlflow.log_text(report_text, "classification_report.txt")
|
||||
|
||||
# Log model to MLflow
|
||||
logger.info("Logging model to MLflow")
|
||||
if training_config.model_type == "random_forest":
|
||||
mlflow.sklearn.log_model(
|
||||
model.model,
|
||||
"model",
|
||||
registered_model_name=(
|
||||
config.stages.inference.mlflow_model_name
|
||||
if mlflow_config.register_model else None
|
||||
)
|
||||
)
|
||||
elif training_config.model_type == "xgboost":
|
||||
mlflow.xgboost.log_model(
|
||||
model.model,
|
||||
"model",
|
||||
registered_model_name=(
|
||||
config.stages.inference.mlflow_model_name
|
||||
if mlflow_config.register_model else None
|
||||
)
|
||||
)
|
||||
|
||||
# Save model locally if path provided
|
||||
if output_model_path:
|
||||
import joblib
|
||||
output_model_path = Path(output_model_path)
|
||||
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(model, output_model_path)
|
||||
logger.info(f"Saved model to {output_model_path}")
|
||||
|
||||
# Update training run record in PostgreSQL
|
||||
metrics_summary = {
|
||||
"test_accuracy": float(test_accuracy),
|
||||
"test_f1_macro": float(test_f1_macro),
|
||||
"test_f1_weighted": float(test_f1_weighted),
|
||||
"val_accuracy": float(val_accuracy),
|
||||
"val_f1_macro": float(val_f1_macro),
|
||||
"val_f1_weighted": float(val_f1_weighted)
|
||||
}
|
||||
|
||||
with get_db() as db:
|
||||
training_run = db.query(TrainingRun).filter(
|
||||
TrainingRun.run_id == run_id
|
||||
).first()
|
||||
if training_run:
|
||||
training_run.metrics_summary = metrics_summary
|
||||
training_run.status = "completed"
|
||||
training_run.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Training run {run_id} completed successfully")
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from app.config import load_config
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train candlestick pattern model")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="config/pipeline.yaml",
|
||||
help="Path to pipeline config file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-model",
|
||||
type=str,
|
||||
default="models/best_model.pkl",
|
||||
help="Path to save trained model"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = load_config(args.config)
|
||||
labeled_data_path = Path(config.data.labeled_path)
|
||||
|
||||
if not labeled_data_path.exists():
|
||||
logger.error(f"Labeled data not found: {labeled_data_path}")
|
||||
logger.error("Run annotation ingestion stage first")
|
||||
exit(1)
|
||||
|
||||
run_id = train(config, labeled_data_path, Path(args.output_model))
|
||||
print(f"Training complete. Run ID: {run_id}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue