feat(ml): implement training stage with MLflow tracking and model wrappers

- Create RandomForestModel and XGBoostModel wrappers with class weight support
- Implement temporal and random train/val/test splitting
- Add MLflow experiment tracking with full parameter and metric logging
- Create evaluation module for confusion matrix, feature importance, and classification reports
- Implement model training with sklearn/xgboost flavor logging and optional registry registration
- Store training run metadata in PostgreSQL
- Wire training stage into pipeline.py orchestrator
- Support both RandomForest and XGBoost models with configurable hyperparameters
This commit is contained in:
Marko Djordjevic 2026-02-15 14:22:19 +01:00
parent 16763b967e
commit f4c0f9a836
8 changed files with 900 additions and 14 deletions

View file

@ -36,18 +36,18 @@
## 5. Training Stage
- [ ] 5.1 Create `services/ml/training/train.py` — main training entry point: load labeled CSV, split, train, evaluate, log to MLflow
- [ ] 5.2 Implement temporal train/validation/test splitting with configurable ratios, warn on random split
- [ ] 5.3 Create `services/ml/training/models/random_forest.py` — RandomForestClassifier wrapper with class_weights support
- [ ] 5.4 Create `services/ml/training/models/xgboost_model.py` — XGBClassifier wrapper with class_weights support
- [ ] 5.5 Implement model dispatch — select model class based on `model_type` config, fail with supported types list for unknown types
- [ ] 5.6 Implement MLflow experiment tracking — create run, log config artifact, dataset params, per-class sample counts, all hyperparameters
- [ ] 5.7 Implement metrics logging — accuracy, f1_macro, f1_weighted, per-class precision/recall/F1
- [ ] 5.8 Create `services/ml/training/evaluation.py` — generate confusion matrix plot, feature importance plot, classification report text
- [ ] 5.9 Implement MLflow artifact logging — log confusion_matrix.png, feature_importance.png, classification_report.txt, pipeline_config.yaml
- [ ] 5.10 Implement MLflow model registration — log model with sklearn/xgboost flavor, register in registry if configured
- [ ] 5.11 Store training run metadata in PostgreSQL `training_runs` table
- [ ] 5.12 Wire training into `pipeline.py`
- [x] 5.1 Create `services/ml/training/train.py` — main training entry point: load labeled CSV, split, train, evaluate, log to MLflow
- [x] 5.2 Implement temporal train/validation/test splitting with configurable ratios, warn on random split
- [x] 5.3 Create `services/ml/training/models/random_forest.py` — RandomForestClassifier wrapper with class_weights support
- [x] 5.4 Create `services/ml/training/models/xgboost_model.py` — XGBClassifier wrapper with class_weights support
- [x] 5.5 Implement model dispatch — select model class based on `model_type` config, fail with supported types list for unknown types
- [x] 5.6 Implement MLflow experiment tracking — create run, log config artifact, dataset params, per-class sample counts, all hyperparameters
- [x] 5.7 Implement metrics logging — accuracy, f1_macro, f1_weighted, per-class precision/recall/F1
- [x] 5.8 Create `services/ml/training/evaluation.py` — generate confusion matrix plot, feature importance plot, classification report text
- [x] 5.9 Implement MLflow artifact logging — log confusion_matrix.png, feature_importance.png, classification_report.txt, pipeline_config.yaml
- [x] 5.10 Implement MLflow model registration — log model with sklearn/xgboost flavor, register in registry if configured
- [x] 5.11 Store training run metadata in PostgreSQL `training_runs` table
- [x] 5.12 Wire training into `pipeline.py`
## 6. Inference Service (FastAPI)

View file

@ -93,10 +93,16 @@ def run_training(config: PipelineConfig) -> None:
return
# Import here to avoid circular dependencies
from training.train import run_training_stage
from training.train import train
logger.info(f"Reading labeled data from: {config.data.labeled_path}")
run_training_stage(config)
# Set output model path from config
output_model_path = Path(config.stages.inference.local_model_path)
# Run training
run_id = train(config, Path(config.data.labeled_path), output_model_path)
logger.info(f"Training completed. MLflow run ID: {run_id}")
logger.info("Training stage completed successfully")

View file

View file

@ -0,0 +1,215 @@
"""
Model evaluation utilities.
Generate confusion matrix plots, feature importance plots, and classification reports.
"""
from pathlib import Path
from typing import List, Optional
import io
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg') # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
def generate_confusion_matrix_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Optional[List[str]] = None,
    normalize: bool = True
) -> bytes:
    """
    Generate confusion matrix plot as PNG bytes.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: Class labels to include, in display order. They are passed to
            sklearn's ``confusion_matrix`` and reused as tick labels
            (optional, inferred if not provided)
        normalize: Whether to normalize each row (true class) to percentages

    Returns:
        PNG image as bytes
    """
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if normalize:
        # A class listed in `labels` may have no true samples in this split.
        # Its row total is then 0, so divide only where the total is non-zero
        # and leave the remaining cells at 0 instead of producing NaN.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        fmt = '.2%'
    else:
        fmt = 'd'

    tick_labels = labels if labels else 'auto'

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Plot heatmap
        sns.heatmap(
            cm,
            annot=True,
            fmt=fmt,
            cmap='Blues',
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ax=ax
        )
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')
        ax.set_title('Confusion Matrix')

        # Render to an in-memory PNG
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)

    return buf.getvalue()
def generate_feature_importance_plot(
    feature_names: List[str],
    importances: np.ndarray,
    top_n: int = 20
) -> bytes:
    """
    Generate feature importance plot as PNG bytes.

    Args:
        feature_names: Names of features
        importances: Feature importance scores, one per feature name
        top_n: Maximum number of top features to display

    Returns:
        PNG image as bytes
    """
    # Create DataFrame and keep the highest-scoring features
    df = pd.DataFrame({
        'feature': feature_names,
        'importance': importances
    })
    df = df.sort_values('importance', ascending=False).head(top_n)

    # Fewer than top_n features may exist; size and title the plot by what
    # is actually drawn rather than by what was requested.
    n_shown = len(df)

    # Create figure
    fig, ax = plt.subplots(figsize=(10, max(6, n_shown * 0.3)))
    try:
        # Horizontal bar plot, most important feature at the top
        positions = range(n_shown)
        ax.barh(positions, df['importance'].values)
        ax.set_yticks(positions)
        ax.set_yticklabels(df['feature'].values)
        ax.invert_yaxis()
        ax.set_xlabel('Importance Score')
        ax.set_title(f'Top {n_shown} Feature Importances')
        ax.grid(axis='x', alpha=0.3)

        # Render to an in-memory PNG
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)

    return buf.getvalue()
def generate_classification_report_text(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Optional[List[str]] = None
) -> str:
    """
    Generate classification report as text.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: Class labels to report on, in display order (optional,
            inferred from the data if not provided)

    Returns:
        Classification report as string
    """
    if labels is None or len(labels) == 0:
        return classification_report(y_true, y_pred, digits=4)

    # Pass the labels both as the set of classes to report and as their
    # display names. Supplying only target_names fails when a class is absent
    # from y_true/y_pred, because the inferred class count then differs from
    # the number of names.
    return classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=[str(label) for label in labels],
        digits=4
    )
def save_confusion_matrix_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_path: Path,
    labels: Optional[List[str]] = None,
    normalize: bool = True
):
    """
    Render the confusion matrix as a PNG and write it to disk.

    Missing parent directories of the target file are created.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        output_path: Destination of the PNG file
        labels: Class label names (optional)
        normalize: Whether to normalize to percentages
    """
    image = generate_confusion_matrix_plot(y_true, y_pred, labels, normalize)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(image)
def save_feature_importance_plot(
    feature_names: List[str],
    importances: np.ndarray,
    output_path: Path,
    top_n: int = 20
):
    """
    Render the feature importance chart as a PNG and write it to disk.

    Missing parent directories of the target file are created.

    Args:
        feature_names: Names of features
        importances: Feature importance scores
        output_path: Destination of the PNG file
        top_n: Number of top features to display
    """
    image = generate_feature_importance_plot(feature_names, importances, top_n)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(image)
def save_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_path: Path,
    labels: Optional[List[str]] = None
):
    """
    Build the classification report and write it to a text file.

    Missing parent directories of the target file are created.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        output_path: Destination of the text file
        labels: Class label names (optional)
    """
    text = generate_classification_report_text(y_true, y_pred, labels)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(text)

View file

View file

@ -0,0 +1,100 @@
"""
RandomForest model wrapper for candlestick pattern classification.
Provides a wrapper around scikit-learn's RandomForestClassifier with
support for class weight balancing.
"""
from typing import Any, Dict, Optional
import numpy as np
from sklearn.ensemble import RandomForestClassifier
class RandomForestModel:
    """
    Thin wrapper around scikit-learn's RandomForestClassifier.

    Attributes:
        hyperparameters: Copy of the hyperparameters used to build the model
            (including ``class_weight`` when balancing is enabled)
        class_weights: The class weighting mode requested by the caller
        model: The underlying RandomForestClassifier instance
        classes_: Fitted class labels
        feature_importances_: Feature importance scores (after fitting)
    """

    def __init__(self, hyperparameters: Dict[str, Any], class_weights: Optional[str] = None):
        """
        Build the underlying classifier.

        Args:
            hyperparameters: Model hyperparameters from config (copied, the
                caller's dict is not modified)
            class_weights: "balanced" for inverse-frequency weighting, None for no weighting
        """
        params = hyperparameters.copy()
        if class_weights == "balanced":
            # scikit-learn handles balancing natively via class_weight
            params["class_weight"] = "balanced"

        self.class_weights = class_weights
        self.hyperparameters = params
        self.model = RandomForestClassifier(**params)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit the wrapped classifier.

        Args:
            X: Training features (n_samples, n_features)
            y: Training labels (n_samples,)

        Returns:
            self
        """
        self.model.fit(X, y)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels with the wrapped classifier.

        Args:
            X: Features (n_samples, n_features)

        Returns:
            Predicted labels (n_samples,)
        """
        predictions = self.model.predict(X)
        return predictions

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities with the wrapped classifier.

        Args:
            X: Features (n_samples, n_features)

        Returns:
            Class probabilities (n_samples, n_classes)
        """
        probabilities = self.model.predict_proba(X)
        return probabilities

    @property
    def classes_(self):
        """Class labels exposed by the wrapped classifier."""
        return self.model.classes_

    @property
    def feature_importances_(self):
        """Feature importance scores exposed by the wrapped classifier."""
        return self.model.feature_importances_

    def get_params(self) -> Dict[str, Any]:
        """
        Return the wrapped classifier's parameters.

        Returns:
            Dictionary of model hyperparameters
        """
        return self.model.get_params()

    def __repr__(self):
        n_estimators = self.hyperparameters.get('n_estimators', 100)
        return f"RandomForestModel(n_estimators={n_estimators})"

View file

@ -0,0 +1,120 @@
"""
XGBoost model wrapper for candlestick pattern classification.
Provides a wrapper around XGBoost's XGBClassifier with support for
class weight balancing.
"""
from typing import Any, Dict, Optional
import numpy as np
from xgboost import XGBClassifier
from sklearn.utils.class_weight import compute_class_weight
class XGBoostModel:
    """
    Thin wrapper around XGBoost's XGBClassifier.

    Attributes:
        hyperparameters: Copy of the hyperparameters used to build the model
        class_weights: The class weighting mode requested by the caller
        model: The underlying XGBClassifier instance
        classes_: Fitted class labels
        feature_importances_: Feature importance scores (after fitting)
    """

    def __init__(self, hyperparameters: Dict[str, Any], class_weights: Optional[str] = None):
        """
        Build the underlying classifier.

        Args:
            hyperparameters: Model hyperparameters from config (copied, the
                caller's dict is not modified)
            class_weights: "balanced" for inverse-frequency weighting, None for no weighting
        """
        params = hyperparameters.copy()

        self.class_weights = class_weights
        self.hyperparameters = params
        # Initialized here but not read or updated anywhere in this class.
        self._sample_weights = None

        # Balancing is not passed to the constructor; it is applied in fit()
        # through per-sample weights.
        self.model = XGBClassifier(**params)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit the wrapped classifier, weighting samples when balancing is enabled.

        Args:
            X: Training features (n_samples, n_features)
            y: Training labels (n_samples,)

        Returns:
            self
        """
        # NOTE(review): recent XGBoost releases reportedly expect labels
        # encoded as integers 0..n_classes-1 - confirm callers pass such labels.
        fit_kwargs = {}

        if self.class_weights == "balanced":
            # One weight per class present in y ...
            present_classes = np.unique(y)
            per_class = compute_class_weight(
                class_weight="balanced",
                classes=present_classes,
                y=y
            )
            # ... expanded to one weight per training sample
            weight_of = dict(zip(present_classes, per_class))
            fit_kwargs["sample_weight"] = np.array([weight_of[label] for label in y])

        self.model.fit(X, y, **fit_kwargs)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels with the wrapped classifier.

        Args:
            X: Features (n_samples, n_features)

        Returns:
            Predicted labels (n_samples,)
        """
        predictions = self.model.predict(X)
        return predictions

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities with the wrapped classifier.

        Args:
            X: Features (n_samples, n_features)

        Returns:
            Class probabilities (n_samples, n_classes)
        """
        probabilities = self.model.predict_proba(X)
        return probabilities

    @property
    def classes_(self):
        """Class labels exposed by the wrapped classifier."""
        return self.model.classes_

    @property
    def feature_importances_(self):
        """Feature importance scores exposed by the wrapped classifier."""
        return self.model.feature_importances_

    def get_params(self) -> Dict[str, Any]:
        """
        Return the wrapped classifier's parameters.

        Returns:
            Dictionary of model hyperparameters
        """
        return self.model.get_params()

    def __repr__(self):
        n_estimators = self.hyperparameters.get('n_estimators', 100)
        return f"XGBoostModel(n_estimators={n_estimators})"

View file

@ -0,0 +1,445 @@
"""
Model training module.
Main training entry point: load labeled CSV, split, train, evaluate, log to MLflow.
"""
import hashlib
import logging
from datetime import datetime
from pathlib import Path
from typing import Tuple, Optional
import warnings
import pandas as pd
import numpy as np
import mlflow
import mlflow.sklearn
import mlflow.xgboost
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from app.config import PipelineConfig, TrainingConfig
from app.db import get_db, TrainingRun, init_db
from training.models.random_forest import RandomForestModel
from training.models.xgboost_model import XGBoostModel
from training import evaluation
# Configure the root logger at import time so INFO-level messages are shown
# when no other logging configuration has been installed yet.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the functions in this module.
logger = logging.getLogger(__name__)
def temporal_split(
    X: np.ndarray,
    y: np.ndarray,
    test_split: float,
    validation_split: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Split data temporally into train/validation/test sets.

    Data is assumed to already be sorted by time; rows are never shuffled.

    Split ratios:
    - test_split: fraction for test set (from the end)
    - validation_split: fraction for validation set (from remaining data after test)
    - remainder: training set

    Args:
        X: Feature matrix (n_samples, n_features)
        y: Labels (n_samples,)
        test_split: Test set fraction (0.0-1.0)
        validation_split: Validation set fraction (0.0-1.0)

    Returns:
        X_train, X_val, X_test, y_train, y_val, y_test

    Raises:
        ValueError: If X and y differ in length, or a split fraction lies
            outside the range 0.0-1.0
    """
    n_samples = len(X)

    if len(y) != n_samples:
        raise ValueError(
            f"X and y must have the same number of samples, got {n_samples} and {len(y)}"
        )

    # A fraction outside [0, 1] yields negative counts below, and negative
    # slice bounds would silently produce overlapping train/test sets.
    for name, fraction in (("test_split", test_split), ("validation_split", validation_split)):
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"{name} must be between 0.0 and 1.0, got {fraction}")

    # Calculate split sizes; the training set absorbs any rounding remainder
    n_test = int(n_samples * test_split)
    n_val = int((n_samples - n_test) * validation_split)
    n_train = n_samples - n_test - n_val
    val_end = n_train + n_val

    # Slice in chronological order: train, then validation, then test
    X_train, y_train = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:val_end], y[n_train:val_end]
    X_test, y_test = X[val_end:], y[val_end:]

    logger.info(f"Temporal split: train={n_train}, val={n_val}, test={n_test}")
    return X_train, X_val, X_test, y_train, y_val, y_test
def random_split(
    X: np.ndarray,
    y: np.ndarray,
    test_split: float,
    validation_split: float,
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Split data into train/validation/test sets with stratified random sampling.

    WARNING: Not recommended for financial time series data. A UserWarning
    is emitted and a warning is logged on every call.

    Args:
        X: Feature matrix (n_samples, n_features)
        y: Labels (n_samples,)
        test_split: Test set fraction (0.0-1.0)
        validation_split: Validation set fraction (0.0-1.0)
        random_state: Random seed

    Returns:
        X_train, X_val, X_test, y_train, y_val, y_test
    """
    from sklearn.model_selection import train_test_split

    warnings.warn(
        "Random splitting is not recommended for financial time series data. "
        "Use temporal splitting instead to avoid data leakage.",
        UserWarning
    )
    logger.warning("Using random split (not recommended for time series)")

    # Stage 1: carve the test set off the full dataset
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y,
        test_size=test_split,
        random_state=random_state,
        stratify=y
    )

    # Stage 2: carve the validation set off what is left. The fraction is
    # rescaled so validation_split is relative to the full dataset.
    # NOTE(review): temporal_split applies validation_split to the data
    # remaining after the test split instead - confirm which is intended.
    val_fraction_of_rest = validation_split / (1 - test_split)
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest,
        test_size=val_fraction_of_rest,
        random_state=random_state,
        stratify=y_rest
    )

    n_train, n_val, n_test = len(X_train), len(X_val), len(X_test)
    logger.info(f"Random split: train={n_train}, val={n_val}, test={n_test}")
    return X_train, X_val, X_test, y_train, y_val, y_test
def create_model(model_type: str, hyperparameters: dict, class_weights: Optional[str]):
    """
    Instantiate the model wrapper selected by model_type.

    Args:
        model_type: "random_forest" or "xgboost"
        hyperparameters: Model hyperparameters
        class_weights: "balanced" or None

    Returns:
        Model instance

    Raises:
        ValueError: If model_type is not supported
    """
    supported_types = ["random_forest", "xgboost"]

    # Guard clause: reject unknown types up front, listing the valid ones
    if model_type not in supported_types:
        raise ValueError(
            f"Unsupported model type: {model_type}. "
            f"Supported types: {supported_types}"
        )

    if model_type == "random_forest":
        model_cls = RandomForestModel
    else:
        model_cls = XGBoostModel
    return model_cls(hyperparameters, class_weights)
def compute_config_hash(config: PipelineConfig) -> str:
    """
    Fingerprint a pipeline configuration.

    The configuration is dumped to a dict, serialized as JSON with sorted
    keys, and hashed with SHA256.

    Args:
        config: Pipeline configuration

    Returns:
        SHA256 hash (first 16 chars)
    """
    import json

    serialized = json.dumps(config.model_dump(), sort_keys=True)
    digest = hashlib.sha256(serialized.encode())
    return digest.hexdigest()[:16]
def _log_png_artifact(png_bytes: bytes, filename: str) -> None:
    """Log in-memory PNG bytes to the active MLflow run under the given file name."""
    import tempfile

    # mlflow.log_artifact keeps the local file's base name, and its second
    # argument is a destination directory rather than a file name. Write the
    # bytes to a file carrying the exact desired name and log that file. The
    # temporary directory is removed even if logging fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = Path(tmp_dir) / filename
        local_path.write_bytes(png_bytes)
        mlflow.log_artifact(str(local_path))


def train(
    config: PipelineConfig,
    labeled_data_path: Path,
    output_model_path: Optional[Path] = None
) -> str:
    """
    Main training function.

    Loads labeled data, splits, trains model, evaluates, logs to MLflow,
    and stores metadata in PostgreSQL.

    Args:
        config: Pipeline configuration
        labeled_data_path: Path to labeled CSV file
        output_model_path: Optional path to save model locally (for inference)

    Returns:
        MLflow run ID

    Raises:
        ValueError: If the dataset has no 'label' column, if split_method is
            neither "temporal" nor "random", or if the model type is not
            supported
    """
    training_config = config.stages.training
    mlflow_config = training_config.mlflow

    # Initialize database
    init_db()

    # Point MLflow at the configured tracking server and experiment
    mlflow.set_tracking_uri(mlflow_config.tracking_uri)
    mlflow.set_experiment(mlflow_config.experiment_name)

    logger.info(f"Loading labeled data from {labeled_data_path}")
    df = pd.read_csv(labeled_data_path)

    # Separate features and labels
    if 'label' not in df.columns:
        raise ValueError("Labeled dataset must have 'label' column")

    label_col = 'label'
    # Every column except the label and time columns is treated as a feature
    feature_cols = [col for col in df.columns if col not in ['label', 'time', 'timestamp']]

    X = df[feature_cols].values
    y = df[label_col].values
    feature_names = feature_cols

    logger.info(f"Loaded {len(X)} samples with {len(feature_cols)} features")
    logger.info(f"Class distribution: {dict(zip(*np.unique(y, return_counts=True)))}")

    # Split data
    if training_config.split_method == "temporal":
        X_train, X_val, X_test, y_train, y_val, y_test = temporal_split(
            X, y,
            training_config.test_split,
            training_config.validation_split
        )
    elif training_config.split_method == "random":
        X_train, X_val, X_test, y_train, y_val, y_test = random_split(
            X, y,
            training_config.test_split,
            training_config.validation_split,
            random_state=training_config.hyperparameters.get('random_state', 42)
        )
    else:
        raise ValueError(f"Unknown split_method: {training_config.split_method}")

    # Start MLflow run
    with mlflow.start_run() as run:
        run_id = run.info.run_id
        logger.info(f"Started MLflow run: {run_id}")

        # Create training run record in PostgreSQL
        # NOTE(review): if anything below raises, this record keeps the
        # status "running" - consider marking such runs as failed.
        with get_db() as db:
            training_run = TrainingRun(
                run_id=run_id,
                model_type=training_config.model_type,
                experiment_name=mlflow_config.experiment_name,
                pipeline_config_hash=compute_config_hash(config),
                dataset_version=None,  # TODO: Add DVC hash if available
                metrics_summary={},
                status="running",
                created_at=datetime.utcnow()
            )
            db.add(training_run)
            db.commit()

        # Log parameters
        mlflow.log_param("model_type", training_config.model_type)
        mlflow.log_param("split_method", training_config.split_method)
        mlflow.log_param("test_split", training_config.test_split)
        mlflow.log_param("validation_split", training_config.validation_split)
        mlflow.log_param("class_weights", training_config.class_weights)
        mlflow.log_param("n_train_samples", len(X_train))
        mlflow.log_param("n_val_samples", len(X_val))
        mlflow.log_param("n_test_samples", len(X_test))
        mlflow.log_param("n_features", X.shape[1])
        mlflow.log_param("n_classes", len(np.unique(y)))

        # Log per-class sample counts
        for label, count in zip(*np.unique(y_train, return_counts=True)):
            mlflow.log_param(f"train_samples_{label}", int(count))

        # Log all hyperparameters
        for param, value in training_config.hyperparameters.items():
            mlflow.log_param(param, value)

        # Log pipeline config as artifact
        import yaml
        config_dict = config.model_dump()
        config_yaml = yaml.dump(config_dict, default_flow_style=False)
        mlflow.log_text(config_yaml, "pipeline_config.yaml")

        # Create and train model
        logger.info(f"Training {training_config.model_type} model")
        model = create_model(
            training_config.model_type,
            training_config.hyperparameters,
            training_config.class_weights
        )
        model.fit(X_train, y_train)
        logger.info("Training complete")

        # Evaluate on validation set
        y_val_pred = model.predict(X_val)
        val_accuracy = accuracy_score(y_val, y_val_pred)
        val_f1_macro = f1_score(y_val, y_val_pred, average='macro')
        val_f1_weighted = f1_score(y_val, y_val_pred, average='weighted')

        mlflow.log_metric("val_accuracy", val_accuracy)
        mlflow.log_metric("val_f1_macro", val_f1_macro)
        mlflow.log_metric("val_f1_weighted", val_f1_weighted)

        # Evaluate on test set
        y_test_pred = model.predict(X_test)
        test_accuracy = accuracy_score(y_test, y_test_pred)
        test_f1_macro = f1_score(y_test, y_test_pred, average='macro')
        test_f1_weighted = f1_score(y_test, y_test_pred, average='weighted')

        mlflow.log_metric("test_accuracy", test_accuracy)
        mlflow.log_metric("test_f1_macro", test_f1_macro)
        mlflow.log_metric("test_f1_weighted", test_f1_weighted)

        logger.info(f"Test accuracy: {test_accuracy:.4f}")
        logger.info(f"Test F1 (macro): {test_f1_macro:.4f}")
        logger.info(f"Test F1 (weighted): {test_f1_weighted:.4f}")

        # Compute per-class metrics
        classes = model.classes_
        precision, recall, f1, support = precision_recall_fscore_support(
            y_test, y_test_pred, labels=classes, average=None
        )

        for i, label in enumerate(classes):
            mlflow.log_metric(f"test_precision_{label}", precision[i])
            mlflow.log_metric(f"test_recall_{label}", recall[i])
            mlflow.log_metric(f"test_f1_{label}", f1[i])
            logger.info(f"Class {label}: P={precision[i]:.4f}, R={recall[i]:.4f}, F1={f1[i]:.4f}")

        # Generate and log artifacts if enabled
        if mlflow_config.log_artifacts:
            logger.info("Generating evaluation artifacts")

            # Confusion matrix
            cm_bytes = evaluation.generate_confusion_matrix_plot(
                y_test, y_test_pred, labels=classes.tolist()
            )
            _log_png_artifact(cm_bytes, "confusion_matrix.png")

            # Feature importance (if available)
            if hasattr(model, 'feature_importances_'):
                fi_bytes = evaluation.generate_feature_importance_plot(
                    feature_names, model.feature_importances_
                )
                _log_png_artifact(fi_bytes, "feature_importance.png")

            # Classification report
            report_text = evaluation.generate_classification_report_text(
                y_test, y_test_pred, labels=classes.tolist()
            )
            mlflow.log_text(report_text, "classification_report.txt")

        # Log model to MLflow, registering it only when configured to
        logger.info("Logging model to MLflow")
        registered_model_name = (
            config.stages.inference.mlflow_model_name
            if mlflow_config.register_model else None
        )
        if training_config.model_type == "random_forest":
            mlflow.sklearn.log_model(
                model.model,
                "model",
                registered_model_name=registered_model_name
            )
        elif training_config.model_type == "xgboost":
            mlflow.xgboost.log_model(
                model.model,
                "model",
                registered_model_name=registered_model_name
            )

        # Save model locally if path provided (the wrapper object is
        # serialized, not only the underlying estimator)
        if output_model_path:
            import joblib
            output_model_path = Path(output_model_path)
            output_model_path.parent.mkdir(parents=True, exist_ok=True)
            joblib.dump(model, output_model_path)
            logger.info(f"Saved model to {output_model_path}")

        # Update training run record in PostgreSQL
        metrics_summary = {
            "test_accuracy": float(test_accuracy),
            "test_f1_macro": float(test_f1_macro),
            "test_f1_weighted": float(test_f1_weighted),
            "val_accuracy": float(val_accuracy),
            "val_f1_macro": float(val_f1_macro),
            "val_f1_weighted": float(val_f1_weighted)
        }

        with get_db() as db:
            training_run = db.query(TrainingRun).filter(
                TrainingRun.run_id == run_id
            ).first()
            if training_run:
                training_run.metrics_summary = metrics_summary
                training_run.status = "completed"
                training_run.completed_at = datetime.utcnow()
                db.commit()

        logger.info(f"Training run {run_id} completed successfully")

    return run_id
# Command-line entry point: load the pipeline config, check that the labeled
# dataset exists, then run training and print the resulting MLflow run ID.
if __name__ == "__main__":
    import argparse
    from app.config import load_config

    parser = argparse.ArgumentParser(description="Train candlestick pattern model")
    parser.add_argument(
        "--config",
        type=str,
        default="config/pipeline.yaml",
        help="Path to pipeline config file"
    )
    parser.add_argument(
        "--output-model",
        type=str,
        default="models/best_model.pkl",
        help="Path to save trained model"
    )
    args = parser.parse_args()

    config = load_config(args.config)
    labeled_data_path = Path(config.data.labeled_path)

    if not labeled_data_path.exists():
        logger.error(f"Labeled data not found: {labeled_data_path}")
        logger.error("Run annotation ingestion stage first")
        # Raise SystemExit directly: the exit() builtin is injected by the
        # site module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)

    run_id = train(config, labeled_data_path, Path(args.output_model))
    print(f"Training complete. Run ID: {run_id}")