"""
RandomForest model wrapper for candlestick pattern classification.

Provides a wrapper around scikit-learn's RandomForestClassifier with support
for class weight balancing.
"""

from typing import Any, Dict, Optional, Union

import numpy as np
from sklearn.ensemble import RandomForestClassifier


class RandomForestModel:
    """
    RandomForest classifier wrapper for candlestick patterns.

    Attributes:
        hyperparameters: Private copy of the keyword arguments passed to
            RandomForestClassifier (including ``class_weight`` when
            ``class_weights`` was supplied)
        class_weights: The class weighting option given at construction
        model: The underlying RandomForestClassifier instance
        classes_: Fitted class labels (only available after ``fit``)
        feature_importances_: Feature importance scores (only available after
            ``fit``)
    """

    def __init__(
        self,
        hyperparameters: Optional[Dict[str, Any]],
        class_weights: Optional[Union[str, Dict[Any, float]]] = None,
    ):
        """
        Initialize RandomForest model.

        Args:
            hyperparameters: Model hyperparameters from config, forwarded as
                keyword arguments to RandomForestClassifier. ``None`` is
                treated as an empty config (all scikit-learn defaults). The
                mapping is copied, so the caller's dict is never mutated.
            class_weights: Class weighting strategy. ``"balanced"`` for
                inverse-frequency weighting. Any other non-None value (e.g.
                ``"balanced_subsample"`` or a ``{label: weight}`` dict) is
                forwarded unchanged to scikit-learn's ``class_weight``
                parameter, which validates it. ``None`` leaves the config
                untouched, so a ``class_weight`` entry already present in
                ``hyperparameters`` still applies.
        """
        # Copy so that injecting class_weight below never mutates the caller's
        # config dict.
        self.hyperparameters = dict(hyperparameters or {})
        self.class_weights = class_weights

        # Forward any explicit choice rather than honouring only "balanced":
        # an unsupported value should be rejected by scikit-learn, not
        # silently dropped.
        if class_weights is not None:
            self.hyperparameters["class_weight"] = class_weights

        # Initialize scikit-learn model
        self.model = RandomForestClassifier(**self.hyperparameters)

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> "RandomForestModel":
        """
        Train the RandomForest model.

        Args:
            X: Training features (n_samples, n_features)
            y: Training labels (n_samples,)
            sample_weight: Optional per-sample weights (n_samples,), forwarded
                to ``RandomForestClassifier.fit``. ``None`` means uniform
                weights.

        Returns:
            self, so calls can be chained (``model.fit(X, y).predict(X2)``)
        """
        self.model.fit(X, y, sample_weight=sample_weight)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels.

        Args:
            X: Features (n_samples, n_features)

        Returns:
            Predicted labels (n_samples,)
        """
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities.

        Args:
            X: Features (n_samples, n_features)

        Returns:
            Class probabilities (n_samples, n_classes), with columns ordered
            like ``classes_``
        """
        return self.model.predict_proba(X)

    @property
    def classes_(self):
        """Get fitted class labels (accessing before ``fit`` raises)."""
        return self.model.classes_

    @property
    def feature_importances_(self):
        """Get feature importance scores (accessing before ``fit`` raises)."""
        return self.model.feature_importances_

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """
        Get model parameters.

        Args:
            deep: Forwarded to the underlying estimator's ``get_params``.
                Defaults to ``True``, matching the previous behaviour.

        Returns:
            Dictionary of the underlying classifier's parameters (every
            parameter, including scikit-learn defaults, not only those set
            from config)
        """
        return self.model.get_params(deep=deep)

    def __repr__(self) -> str:
        # Read from the live estimator so the repr shows the effective value,
        # including scikit-learn's own default when config omits it.
        return f"{type(self).__name__}(n_estimators={self.model.n_estimators})"