candle-annotator/services/ml/training/evaluation.py
Marko Djordjevic f4c0f9a836 feat(ml): implement training stage with MLflow tracking and model wrappers
- Create RandomForestModel and XGBoostModel wrappers with class weight support
- Implement temporal and random train/val/test splitting
- Add MLflow experiment tracking with full parameter and metric logging
- Create evaluation module for confusion matrix, feature importance, and classification reports
- Implement model training with sklearn/xgboost flavor logging and optional registry registration
- Store training run metadata in PostgreSQL
- Wire training stage into pipeline.py orchestrator
- Support both RandomForest and XGBoost models with configurable hyperparameters
2026-02-15 14:22:19 +01:00

215 lines
5.3 KiB
Python

"""
Model evaluation utilities.
Generate confusion matrix plots, feature importance plots, and classification reports.
"""
from pathlib import Path
from typing import List, Optional
import io
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg') # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
def generate_confusion_matrix_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Optional[List[str]] = None,
    normalize: bool = True
) -> bytes:
    """
    Generate confusion matrix plot as PNG bytes.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: Class labels. Forwarded to ``confusion_matrix`` to select and
            order the rows/columns, and used as the axis tick labels. When
            omitted, the classes are inferred by ``confusion_matrix`` and
            seaborn picks the tick labels automatically.
        normalize: Whether to normalize each row to percentages

    Returns:
        PNG image as bytes
    """
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if normalize:
        # A class that never occurs in y_true (it was only predicted, or only
        # listed in `labels`) has a row total of zero. Dividing only where the
        # total is non-zero leaves such rows at 0% instead of filling them
        # with NaN and emitting a RuntimeWarning.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype(float),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        fmt = '.2%'
    else:
        fmt = 'd'

    # Explicit None check: truthiness would raise for array-like labels.
    tick_labels = list(labels) if labels is not None else 'auto'

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Plot heatmap
        sns.heatmap(
            cm,
            annot=True,
            fmt=fmt,
            cmap='Blues',
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ax=ax
        )
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')
        ax.set_title('Confusion Matrix')

        # Convert to bytes. Use the figure's own methods so the output never
        # depends on which figure pyplot currently considers "current".
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot keeps every unclosed figure alive in its global registry.
        plt.close(fig)

    return buf.getvalue()
def generate_feature_importance_plot(
    feature_names: List[str],
    importances: np.ndarray,
    top_n: int = 20
) -> bytes:
    """
    Generate feature importance plot as PNG bytes.

    Args:
        feature_names: Names of features
        importances: Feature importance scores, one per entry in
            ``feature_names``
        top_n: Maximum number of top features to display

    Returns:
        PNG image as bytes
    """
    # Create DataFrame and sort by importance
    df = pd.DataFrame({
        'feature': feature_names,
        'importance': importances
    })
    df = df.sort_values('importance', ascending=False).head(top_n)

    # There may be fewer features than top_n; size the figure and word the
    # title from what is actually plotted rather than from the requested cap.
    shown = len(df)

    # Create figure
    fig, ax = plt.subplots(figsize=(10, max(6, shown * 0.3)))
    try:
        # Horizontal bar plot
        positions = range(shown)
        ax.barh(positions, df['importance'].values)
        ax.set_yticks(positions)
        ax.set_yticklabels(df['feature'].values)
        # Put the most important feature at the top of the chart
        ax.invert_yaxis()
        ax.set_xlabel('Importance Score')
        ax.set_title(f'Top {shown} Feature Importances')
        ax.grid(axis='x', alpha=0.3)

        # Convert to bytes. Use the figure's own methods so the output never
        # depends on which figure pyplot currently considers "current".
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot keeps every unclosed figure alive in its global registry.
        plt.close(fig)

    return buf.getvalue()
def generate_classification_report_text(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Optional[List[str]] = None
) -> str:
    """
    Generate classification report as text.

    When every value found in ``y_true``/``y_pred`` is itself one of
    ``labels``, the labels are treated as the class values (the same way
    ``generate_confusion_matrix_plot`` forwards them to ``confusion_matrix``)
    and are passed to sklearn as both ``labels`` and ``target_names``. This
    keeps each row attached to the right class regardless of the order of
    ``labels`` and avoids a ``ValueError`` when a class is missing from the
    evaluated data. Otherwise ``labels`` are used purely as display names,
    exactly as before.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: Class label names (optional)

    Returns:
        Classification report as string
    """
    if labels is None:
        return classification_report(y_true, y_pred, digits=4)

    label_list = list(labels)
    observed = set(np.unique(y_true).tolist()) | set(np.unique(y_pred).tolist())

    if observed <= set(label_list):
        # `labels` are the actual class values: pin the row set and order
        # explicitly instead of relying on a positional match against the
        # sorted classes sklearn infers from the data.
        return classification_report(
            y_true,
            y_pred,
            labels=label_list,
            target_names=[str(name) for name in label_list],
            digits=4
        )

    # `labels` are display names for differently-encoded classes:
    # keep the original positional behaviour.
    return classification_report(
        y_true,
        y_pred,
        target_names=labels,
        digits=4
    )
def save_confusion_matrix_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_path: Path,
    labels: Optional[List[str]] = None,
    normalize: bool = True
):
    """
    Render the confusion matrix as a PNG and write it to disk.

    Missing parent directories of ``output_path`` are created first.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        output_path: Destination of the PNG file
        labels: Class label names (optional)
        normalize: Whether to normalize to percentages
    """
    image_data = generate_confusion_matrix_plot(y_true, y_pred, labels, normalize)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(image_data)
def save_feature_importance_plot(
    feature_names: List[str],
    importances: np.ndarray,
    output_path: Path,
    top_n: int = 20
):
    """
    Render the feature importance chart as a PNG and write it to disk.

    Missing parent directories of ``output_path`` are created first.

    Args:
        feature_names: Names of features
        importances: Feature importance scores
        output_path: Destination of the PNG file
        top_n: Number of top features to display
    """
    image_data = generate_feature_importance_plot(feature_names, importances, top_n)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(image_data)
def save_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_path: Path,
    labels: Optional[List[str]] = None
):
    """
    Generate and save classification report to text file.

    Missing parent directories of ``output_path`` are created first. The file
    is written as UTF-8.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        output_path: Path to save text file
        labels: Class label names (optional)
    """
    report = generate_classification_report_text(y_true, y_pred, labels)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Explicit encoding: class names appear verbatim in the report, and the
    # platform default encoding may be unable to represent non-ASCII names.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(report)