feat(ml): add database schema, config parser, and DVC setup
- Initialize DVC with local storage backend (task 1.6) - Create PostgreSQL schema for training_runs table (task 1.7) - Add SQLAlchemy database connection setup (task 1.8) - Create Pydantic config models for pipeline.yaml (task 2.1) - Add migration runner for database setup - Fix pyproject.toml package discovery config
This commit is contained in:
parent
1a653c5866
commit
ea339a54a7
15 changed files with 412 additions and 4 deletions
|
|
@ -5,13 +5,13 @@
|
|||
- [x] 1.3 Create `services/ml/Dockerfile` with Python 3.11, TA-Lib C library installation (`libta-lib-dev`), and pip install of dependencies
|
||||
- [x] 1.4 Create `config/pipeline.yaml` with the full pipeline configuration (all stages, default hyperparameters, MLflow/DVC settings)
|
||||
- [x] 1.5 Add PostgreSQL, ml-service, and mlflow containers to `docker-compose.yml` with shared data volume
|
||||
- [ ] 1.6 Initialize DVC in `services/ml/` with local remote storage backend
|
||||
- [ ] 1.7 Create PostgreSQL database schema: `training_runs` table (run_id, model_type, experiment_name, pipeline_config_hash, dataset_version, metrics_summary JSON, status, created_at, completed_at)
|
||||
- [ ] 1.8 Create `services/ml/app/db.py` — SQLAlchemy engine and session setup for PostgreSQL connection
|
||||
- [x] 1.6 Initialize DVC in `services/ml/` with local remote storage backend
|
||||
- [x] 1.7 Create PostgreSQL database schema: `training_runs` table (run_id, model_type, experiment_name, pipeline_config_hash, dataset_version, metrics_summary JSON, status, created_at, completed_at)
|
||||
- [x] 1.8 Create `services/ml/app/db.py` — SQLAlchemy engine and session setup for PostgreSQL connection
|
||||
|
||||
## 2. Pipeline Config & Entry Point
|
||||
|
||||
- [ ] 2.1 Create `services/ml/app/config.py` — Pydantic model for pipeline YAML config with validation (stages, data paths, hyperparameters)
|
||||
- [x] 2.1 Create `services/ml/app/config.py` — Pydantic model for pipeline YAML config with validation (stages, data paths, hyperparameters)
|
||||
- [ ] 2.2 Create `services/ml/pipeline.py` — main orchestrator that reads config and runs enabled stages in sequence
|
||||
- [ ] 2.3 Add CLI argument parsing: `--config`, `--stage` (run individual stage), support for `python pipeline.py --config config/pipeline.yaml`
|
||||
|
||||
|
|
|
|||
5
services/ml/.dvc/config
Normal file
5
services/ml/.dvc/config
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
[core]
|
||||
no_scm = True
|
||||
remote = local
|
||||
['remote "local"']
|
||||
url = /home/homoludens/projekti/bitcon/candle_annotator/services/ml/dvc_storage
|
||||
0
services/ml/.dvc/tmp/btime
Normal file
0
services/ml/.dvc/tmp/btime
Normal file
3
services/ml/.dvcignore
Normal file
3
services/ml/.dvcignore
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Add patterns of files dvc should ignore, which could improve
|
||||
# the performance. Learn more at
|
||||
# https://dvc.org/doc/user-guide/dvcignore
|
||||
1
services/ml/app/__init__.py
Normal file
1
services/ml/app/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""ML service application package."""
|
||||
147
services/ml/app/config.py
Normal file
147
services/ml/app/config.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
"""
|
||||
Pipeline configuration module.
|
||||
|
||||
Pydantic models for validating and loading the pipeline.yaml configuration.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class TALibIndicator(BaseModel):
|
||||
"""Configuration for a single TA-Lib indicator."""
|
||||
name: str
|
||||
params: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FeatureEngineeringConfig(BaseModel):
|
||||
"""Feature engineering stage configuration."""
|
||||
enabled: bool = True
|
||||
talib_indicators: List[TALibIndicator] = Field(default_factory=list)
|
||||
candle_features: bool = True
|
||||
custom_features: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ProgrammaticLabelsConfig(BaseModel):
|
||||
"""Configuration for programmatic TA-Lib pattern labels."""
|
||||
enabled: bool = True
|
||||
talib_patterns: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AnnotationIngestionConfig(BaseModel):
|
||||
"""Annotation ingestion stage configuration."""
|
||||
enabled: bool = True
|
||||
label_encoding: Literal["window", "bio"] = "window"
|
||||
window_size: int = 30
|
||||
context_padding: int = 20
|
||||
min_confidence: int = 1
|
||||
programmatic_labels: ProgrammaticLabelsConfig = Field(
|
||||
default_factory=ProgrammaticLabelsConfig
|
||||
)
|
||||
merge_strategy: Literal["human_priority", "programmatic_priority", "both"] = "human_priority"
|
||||
|
||||
|
||||
class MLflowConfig(BaseModel):
|
||||
"""MLflow experiment tracking configuration."""
|
||||
tracking_uri: str = "http://mlflow:5000"
|
||||
experiment_name: str = "candlestick_patterns"
|
||||
log_artifacts: bool = True
|
||||
register_model: bool = False
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Training stage configuration."""
|
||||
enabled: bool = True
|
||||
model_type: Literal["random_forest", "xgboost"] = "random_forest"
|
||||
split_method: Literal["temporal", "random"] = "temporal"
|
||||
test_split: float = Field(0.2, ge=0.0, le=1.0)
|
||||
validation_split: float = Field(0.1, ge=0.0, le=1.0)
|
||||
class_weights: Optional[Literal["balanced"]] = "balanced"
|
||||
hyperparameters: Dict[str, Any] = Field(default_factory=dict)
|
||||
mlflow: MLflowConfig = Field(default_factory=MLflowConfig)
|
||||
|
||||
@field_validator("test_split", "validation_split")
|
||||
@classmethod
|
||||
def validate_split(cls, v):
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Split must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
|
||||
class InferenceConfig(BaseModel):
|
||||
"""Inference stage configuration."""
|
||||
enabled: bool = True
|
||||
model_source: Literal["mlflow", "local"] = "local"
|
||||
mlflow_model_name: Optional[str] = "candlestick_pattern_v1"
|
||||
mlflow_model_stage: Literal["Production", "Staging", "None"] = "Production"
|
||||
local_model_path: str = "models/best_model.pkl"
|
||||
batch_size: int = 1000
|
||||
use_training_config: bool = True
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
"""Data paths configuration."""
|
||||
raw_path: str = "data/raw/OHLCV.csv"
|
||||
enriched_path: str = "data/enriched/features.csv"
|
||||
labeled_path: str = "data/labeled/dataset.csv"
|
||||
annotations_path: str = "data/annotations/export.json"
|
||||
|
||||
|
||||
class StagesConfig(BaseModel):
|
||||
"""All pipeline stages configuration."""
|
||||
feature_engineering: FeatureEngineeringConfig = Field(
|
||||
default_factory=FeatureEngineeringConfig
|
||||
)
|
||||
annotation_ingestion: AnnotationIngestionConfig = Field(
|
||||
default_factory=AnnotationIngestionConfig
|
||||
)
|
||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
|
||||
|
||||
class PipelineConfig(BaseModel):
|
||||
"""Root pipeline configuration."""
|
||||
data: DataConfig = Field(default_factory=DataConfig)
|
||||
stages: StagesConfig = Field(default_factory=StagesConfig)
|
||||
|
||||
|
||||
def load_config(config_path: str | Path) -> PipelineConfig:
|
||||
"""
|
||||
Load and validate pipeline configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_path: Path to pipeline.yaml file
|
||||
|
||||
Returns:
|
||||
Validated PipelineConfig object
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file doesn't exist
|
||||
ValueError: If config validation fails
|
||||
yaml.YAMLError: If YAML parsing fails
|
||||
"""
|
||||
config_path = Path(config_path)
|
||||
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
try:
|
||||
return PipelineConfig(**config_dict)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Config validation failed: {e}")
|
||||
|
||||
|
||||
def get_default_config() -> PipelineConfig:
|
||||
"""
|
||||
Get default pipeline configuration.
|
||||
|
||||
Returns:
|
||||
PipelineConfig with default values
|
||||
"""
|
||||
return PipelineConfig()
|
||||
106
services/ml/app/db.py
Normal file
106
services/ml/app/db.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""
|
||||
Database connection and session management for the ML service.
|
||||
|
||||
This module provides SQLAlchemy engine and session setup for PostgreSQL.
|
||||
Environment variables control the connection parameters.
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, JSON
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
|
||||
# Database connection configuration from environment
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL",
|
||||
f"postgresql://{os.getenv('POSTGRES_USER', 'ml_user')}:"
|
||||
f"{os.getenv('POSTGRES_PASSWORD', 'ml_password')}@"
|
||||
f"{os.getenv('POSTGRES_HOST', 'localhost')}:"
|
||||
f"{os.getenv('POSTGRES_PORT', '5432')}/"
|
||||
f"{os.getenv('POSTGRES_DB', 'ml_service')}"
|
||||
)
|
||||
|
||||
# Create SQLAlchemy engine
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
pool_pre_ping=True, # Verify connections before using them
|
||||
pool_size=5, # Number of connections to maintain
|
||||
max_overflow=10, # Max connections beyond pool_size
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# Base class for declarative models
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# Training runs model
|
||||
class TrainingRun(Base):
|
||||
"""Model for tracking ML training runs."""
|
||||
|
||||
__tablename__ = "training_runs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
run_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
model_type = Column(String(100), nullable=False)
|
||||
experiment_name = Column(String(255), nullable=False, index=True)
|
||||
pipeline_config_hash = Column(String(64), nullable=False)
|
||||
dataset_version = Column(String(100))
|
||||
metrics_summary = Column(JSON)
|
||||
status = Column(String(50), nullable=False, default="running", index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
|
||||
completed_at = Column(DateTime(timezone=True))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainingRun(run_id='{self.run_id}', status='{self.status}')>"
|
||||
|
||||
|
||||
def init_db():
|
||||
"""
|
||||
Initialize the database schema.
|
||||
Creates all tables defined by Base.metadata.
|
||||
"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Context manager for database sessions.
|
||||
|
||||
Usage:
|
||||
with get_db() as db:
|
||||
# Use db session here
|
||||
training_run = TrainingRun(run_id="123", ...)
|
||||
db.add(training_run)
|
||||
db.commit()
|
||||
|
||||
Yields:
|
||||
Database session
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_db_session() -> Session:
|
||||
"""
|
||||
Get a database session (for dependency injection).
|
||||
|
||||
Usage with FastAPI:
|
||||
@app.get("/")
|
||||
def endpoint(db: Session = Depends(get_db_session)):
|
||||
# Use db here
|
||||
|
||||
Returns:
|
||||
Database session (caller must close it)
|
||||
"""
|
||||
return SessionLocal()
|
||||
22
services/ml/candle_ml.egg-info/PKG-INFO
Normal file
22
services/ml/candle_ml.egg-info/PKG-INFO
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
Metadata-Version: 2.4
|
||||
Name: candle-ml
|
||||
Version: 0.1.0
|
||||
Summary: ML service for candlestick pattern recognition
|
||||
Requires-Python: >=3.11
|
||||
Requires-Dist: fastapi>=0.109.0
|
||||
Requires-Dist: uvicorn[standard]>=0.27.0
|
||||
Requires-Dist: scikit-learn>=1.4.0
|
||||
Requires-Dist: xgboost>=2.0.3
|
||||
Requires-Dist: pandas>=2.2.0
|
||||
Requires-Dist: numpy>=1.26.0
|
||||
Requires-Dist: joblib>=1.3.2
|
||||
Requires-Dist: mlflow>=2.10.0
|
||||
Requires-Dist: pyyaml>=6.0.1
|
||||
Requires-Dist: TA-Lib>=0.4.28
|
||||
Requires-Dist: dvc>=3.40.0
|
||||
Requires-Dist: sqlalchemy>=2.0.25
|
||||
Requires-Dist: psycopg2-binary>=2.9.9
|
||||
Requires-Dist: pydantic>=2.5.0
|
||||
Requires-Dist: pydantic-settings>=2.1.0
|
||||
Requires-Dist: matplotlib>=3.8.2
|
||||
Requires-Dist: seaborn>=0.13.1
|
||||
6
services/ml/candle_ml.egg-info/SOURCES.txt
Normal file
6
services/ml/candle_ml.egg-info/SOURCES.txt
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
pyproject.toml
|
||||
candle_ml.egg-info/PKG-INFO
|
||||
candle_ml.egg-info/SOURCES.txt
|
||||
candle_ml.egg-info/dependency_links.txt
|
||||
candle_ml.egg-info/requires.txt
|
||||
candle_ml.egg-info/top_level.txt
|
||||
1
services/ml/candle_ml.egg-info/dependency_links.txt
Normal file
1
services/ml/candle_ml.egg-info/dependency_links.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
17
services/ml/candle_ml.egg-info/requires.txt
Normal file
17
services/ml/candle_ml.egg-info/requires.txt
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
scikit-learn>=1.4.0
|
||||
xgboost>=2.0.3
|
||||
pandas>=2.2.0
|
||||
numpy>=1.26.0
|
||||
joblib>=1.3.2
|
||||
mlflow>=2.10.0
|
||||
pyyaml>=6.0.1
|
||||
TA-Lib>=0.4.28
|
||||
dvc>=3.40.0
|
||||
sqlalchemy>=2.0.25
|
||||
psycopg2-binary>=2.9.9
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
matplotlib>=3.8.2
|
||||
seaborn>=0.13.1
|
||||
4
services/ml/candle_ml.egg-info/top_level.txt
Normal file
4
services/ml/candle_ml.egg-info/top_level.txt
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
app
|
||||
features
|
||||
inference
|
||||
training
|
||||
27
services/ml/migrations/001_create_training_runs.sql
Normal file
27
services/ml/migrations/001_create_training_runs.sql
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
-- Create training_runs table for tracking ML training runs
|
||||
CREATE TABLE IF NOT EXISTS training_runs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
run_id VARCHAR(255) NOT NULL UNIQUE,
|
||||
model_type VARCHAR(100) NOT NULL,
|
||||
experiment_name VARCHAR(255) NOT NULL,
|
||||
pipeline_config_hash VARCHAR(64) NOT NULL,
|
||||
dataset_version VARCHAR(100),
|
||||
metrics_summary JSONB,
|
||||
status VARCHAR(50) NOT NULL DEFAULT 'running',
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
completed_at TIMESTAMP WITH TIME ZONE,
|
||||
|
||||
CONSTRAINT valid_status CHECK (status IN ('running', 'completed', 'failed', 'cancelled'))
|
||||
);
|
||||
|
||||
-- Create index on run_id for faster lookups
|
||||
CREATE INDEX idx_training_runs_run_id ON training_runs(run_id);
|
||||
|
||||
-- Create index on experiment_name for filtering by experiment
|
||||
CREATE INDEX idx_training_runs_experiment ON training_runs(experiment_name);
|
||||
|
||||
-- Create index on status for filtering active runs
|
||||
CREATE INDEX idx_training_runs_status ON training_runs(status);
|
||||
|
||||
-- Create index on created_at for chronological queries
|
||||
CREATE INDEX idx_training_runs_created_at ON training_runs(created_at DESC);
|
||||
65
services/ml/migrations/run_migrations.py
Executable file
65
services/ml/migrations/run_migrations.py
Executable file
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple database migration runner for the ML service.
|
||||
Runs all SQL files in the migrations directory in order.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
|
||||
def get_db_connection():
|
||||
"""Get database connection from environment variables."""
|
||||
return psycopg2.connect(
|
||||
host=os.getenv("POSTGRES_HOST", "localhost"),
|
||||
port=os.getenv("POSTGRES_PORT", "5432"),
|
||||
database=os.getenv("POSTGRES_DB", "ml_service"),
|
||||
user=os.getenv("POSTGRES_USER", "ml_user"),
|
||||
password=os.getenv("POSTGRES_PASSWORD", "ml_password")
|
||||
)
|
||||
|
||||
|
||||
def run_migrations():
|
||||
"""Run all migration files in order."""
|
||||
migrations_dir = Path(__file__).parent
|
||||
migration_files = sorted(migrations_dir.glob("*.sql"))
|
||||
|
||||
if not migration_files:
|
||||
print("No migration files found")
|
||||
return
|
||||
|
||||
print(f"Found {len(migration_files)} migration file(s)")
|
||||
|
||||
conn = get_db_connection()
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
for migration_file in migration_files:
|
||||
print(f"Running migration: {migration_file.name}")
|
||||
|
||||
with open(migration_file, 'r') as f:
|
||||
migration_sql = f.read()
|
||||
|
||||
cur.execute(migration_sql)
|
||||
conn.commit()
|
||||
|
||||
print(f" ✓ {migration_file.name} completed")
|
||||
|
||||
print("\nAll migrations completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
print(f"\n✗ Migration failed: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_migrations()
|
||||
|
|
@ -26,3 +26,7 @@ dependencies = [
|
|||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["app*", "features*", "training*", "inference*"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue