feat: add Python migration script and successfully test SQLite to PostgreSQL data migration
- Created scripts/migrate-sqlite-to-postgres.py as alternative to TypeScript version - Handles all type conversions: timestamps, booleans, and JSONB fields - Successfully migrated all 2,836 rows from SQLite to PostgreSQL - Verified data integrity: all 6 tables migrated correctly - Charts: 1, Candles: 2,592, Annotations: 4, Span annotations: 223
This commit is contained in:
parent
5f70f13da3
commit
bfe437857b
9 changed files with 1080 additions and 20 deletions
|
|
@ -16,6 +16,7 @@ import pandas as pd
|
|||
import numpy as np
|
||||
|
||||
from app.config import AnnotationIngestionConfig
|
||||
from app.data_access import DataAccess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -95,6 +96,61 @@ class AnnotationIngestion:
|
|||
|
||||
return annotations
|
||||
|
||||
def load_annotations_from_db(
|
||||
self,
|
||||
chart_name: str,
|
||||
source: str = "human"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load annotations directly from PostgreSQL database.
|
||||
|
||||
This method replaces JSON file exports by querying the database directly.
|
||||
|
||||
Args:
|
||||
chart_name: Name of the chart to load annotations for
|
||||
source: Filter by annotation source ('human', 'model', 'hybrid')
|
||||
|
||||
Returns:
|
||||
List of annotation dictionaries compatible with existing processing
|
||||
"""
|
||||
logger.info(f"Loading annotations from database for chart: {chart_name}")
|
||||
|
||||
data_access = DataAccess()
|
||||
|
||||
# Get span annotations from database
|
||||
chart = data_access.get_chart_by_name(chart_name)
|
||||
if not chart:
|
||||
raise ValueError(f"Chart not found: {chart_name}")
|
||||
|
||||
annotations_df = data_access.get_span_annotations(
|
||||
chart_id=chart['id'],
|
||||
source=source,
|
||||
min_confidence=self.config.min_confidence if self.config.min_confidence > 1 else None
|
||||
)
|
||||
|
||||
if annotations_df.empty:
|
||||
logger.warning(f"No annotations found for chart: {chart_name}")
|
||||
return []
|
||||
|
||||
# Convert DataFrame to list of dictionaries compatible with existing code
|
||||
annotations = []
|
||||
for _, row in annotations_df.iterrows():
|
||||
ann = {
|
||||
'id': row['id'],
|
||||
'label': row['label'],
|
||||
'start_time': row['start_time'].isoformat() if pd.notna(row['start_time']) else None,
|
||||
'end_time': row['end_time'].isoformat() if pd.notna(row['end_time']) else None,
|
||||
'confidence': row.get('confidence'),
|
||||
'outcome': row.get('outcome'),
|
||||
'notes': row.get('notes'),
|
||||
'source': row['source'],
|
||||
}
|
||||
annotations.append(ann)
|
||||
|
||||
logger.info(f"Loaded {len(annotations)} annotations from database")
|
||||
|
||||
return annotations
|
||||
|
||||
def get_programmatic_labels(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Generate programmatic labels using TA-Lib CDL* pattern functions.
|
||||
|
|
@ -484,6 +540,55 @@ class AnnotationIngestion:
|
|||
logger.info("Annotation ingestion complete")
|
||||
|
||||
return result_df
|
||||
|
||||
def process_from_db(
|
||||
self,
|
||||
enriched_df: pd.DataFrame,
|
||||
chart_name: str,
|
||||
source: str = "human"
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Main processing pipeline using direct database access.
|
||||
|
||||
This method replaces JSON file exports by querying PostgreSQL directly.
|
||||
|
||||
Args:
|
||||
enriched_df: DataFrame with engineered features
|
||||
chart_name: Name of the chart to load annotations for
|
||||
source: Filter by annotation source ('human', 'model', 'hybrid')
|
||||
|
||||
Returns:
|
||||
Labeled DataFrame ready for training
|
||||
"""
|
||||
logger.info(f"Starting annotation ingestion from database for chart: {chart_name}")
|
||||
|
||||
# Load annotations from database
|
||||
annotations = self.load_annotations_from_db(chart_name, source)
|
||||
|
||||
if not annotations:
|
||||
logger.warning("No annotations found, returning empty DataFrame")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Add programmatic labels if enabled
|
||||
df = self.get_programmatic_labels(enriched_df)
|
||||
|
||||
# Apply label encoding
|
||||
if self.config.label_encoding == "window":
|
||||
result_df = self.create_windowed_dataset(df, annotations)
|
||||
elif self.config.label_encoding == "bio":
|
||||
result_df = self.create_bio_dataset(df, annotations)
|
||||
# For BIO, also merge human/programmatic if enabled
|
||||
if self.config.programmatic_labels.enabled:
|
||||
result_df = self.merge_labels(result_df, annotations)
|
||||
else:
|
||||
raise ValueError(f"Unknown label encoding: {self.config.label_encoding}")
|
||||
|
||||
# Log statistics
|
||||
self.log_statistics(result_df, annotations)
|
||||
|
||||
logger.info("Annotation ingestion complete")
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
def run_annotation_ingestion(
|
||||
|
|
|
|||
249
services/ml/app/data_access.py
Normal file
249
services/ml/app/data_access.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
"""
|
||||
Direct database access for reading candle and annotation data.
|
||||
|
||||
This module provides read-only access to frontend tables (candles, annotations,
|
||||
span_annotations, charts) using SQLAlchemy table reflections or raw queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import Table, MetaData, select, and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import engine, get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataAccess:
|
||||
"""
|
||||
Provides read-only access to frontend database tables.
|
||||
|
||||
Uses SQLAlchemy table reflection to access tables managed by Drizzle ORM
|
||||
without creating duplicate model definitions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize data access with table reflections."""
|
||||
self.metadata = MetaData()
|
||||
|
||||
# Reflect frontend tables
|
||||
try:
|
||||
self.charts = Table('charts', self.metadata, autoload_with=engine)
|
||||
self.candles = Table('candles', self.metadata, autoload_with=engine)
|
||||
self.annotations = Table('annotations', self.metadata, autoload_with=engine)
|
||||
self.span_annotations = Table('span_annotations', self.metadata, autoload_with=engine)
|
||||
self.span_label_types = Table('span_label_types', self.metadata, autoload_with=engine)
|
||||
logger.info("Successfully reflected frontend tables")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reflecting frontend tables: {e}")
|
||||
raise
|
||||
|
||||
def get_chart_by_name(self, chart_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get chart by name.
|
||||
|
||||
Args:
|
||||
chart_name: Name of the chart
|
||||
|
||||
Returns:
|
||||
Chart dictionary or None if not found
|
||||
"""
|
||||
with get_db() as db:
|
||||
stmt = select(self.charts).where(self.charts.c.name == chart_name)
|
||||
result = db.execute(stmt).fetchone()
|
||||
|
||||
if result:
|
||||
return dict(result._mapping)
|
||||
return None
|
||||
|
||||
def get_candles(
|
||||
self,
|
||||
chart_id: int,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Get candle data for a chart.
|
||||
|
||||
Args:
|
||||
chart_id: Chart ID
|
||||
start_time: Optional start time filter
|
||||
end_time: Optional end time filter
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: id, chart_id, time, open, high, low, close
|
||||
"""
|
||||
with get_db() as db:
|
||||
# Build query
|
||||
stmt = select(self.candles).where(self.candles.c.chart_id == chart_id)
|
||||
|
||||
if start_time:
|
||||
stmt = stmt.where(self.candles.c.time >= start_time)
|
||||
if end_time:
|
||||
stmt = stmt.where(self.candles.c.time <= end_time)
|
||||
|
||||
stmt = stmt.order_by(self.candles.c.time)
|
||||
|
||||
result = db.execute(stmt).fetchall()
|
||||
|
||||
if not result:
|
||||
logger.warning(f"No candles found for chart_id={chart_id}")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame([dict(row._mapping) for row in result])
|
||||
logger.info(f"Loaded {len(df)} candles for chart_id={chart_id}")
|
||||
|
||||
return df
|
||||
|
||||
def get_span_annotations(
|
||||
self,
|
||||
chart_id: int,
|
||||
source: Optional[str] = None,
|
||||
min_confidence: Optional[int] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Get span annotations for a chart.
|
||||
|
||||
Args:
|
||||
chart_id: Chart ID
|
||||
source: Optional filter by source ('human', 'model', 'hybrid')
|
||||
min_confidence: Optional minimum confidence filter
|
||||
|
||||
Returns:
|
||||
DataFrame with span annotations
|
||||
"""
|
||||
with get_db() as db:
|
||||
# Build query
|
||||
stmt = select(self.span_annotations).where(
|
||||
self.span_annotations.c.chart_id == chart_id
|
||||
)
|
||||
|
||||
if source:
|
||||
stmt = stmt.where(self.span_annotations.c.source == source)
|
||||
|
||||
if min_confidence is not None:
|
||||
stmt = stmt.where(
|
||||
and_(
|
||||
self.span_annotations.c.confidence.isnot(None),
|
||||
self.span_annotations.c.confidence >= min_confidence
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(self.span_annotations.c.start_time)
|
||||
|
||||
result = db.execute(stmt).fetchall()
|
||||
|
||||
if not result:
|
||||
logger.warning(f"No span annotations found for chart_id={chart_id}")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame([dict(row._mapping) for row in result])
|
||||
logger.info(f"Loaded {len(df)} span annotations for chart_id={chart_id}")
|
||||
|
||||
return df
|
||||
|
||||
def get_point_annotations(
|
||||
self,
|
||||
chart_id: int,
|
||||
label_type: Optional[str] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Get point annotations for a chart.
|
||||
|
||||
Args:
|
||||
chart_id: Chart ID
|
||||
label_type: Optional filter by label type
|
||||
|
||||
Returns:
|
||||
DataFrame with point annotations
|
||||
"""
|
||||
with get_db() as db:
|
||||
# Build query
|
||||
stmt = select(self.annotations).where(
|
||||
self.annotations.c.chart_id == chart_id
|
||||
)
|
||||
|
||||
if label_type:
|
||||
stmt = stmt.where(self.annotations.c.label_type == label_type)
|
||||
|
||||
stmt = stmt.order_by(self.annotations.c.timestamp)
|
||||
|
||||
result = db.execute(stmt).fetchall()
|
||||
|
||||
if not result:
|
||||
logger.warning(f"No point annotations found for chart_id={chart_id}")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame([dict(row._mapping) for row in result])
|
||||
logger.info(f"Loaded {len(df)} point annotations for chart_id={chart_id}")
|
||||
|
||||
return df
|
||||
|
||||
def get_training_data(
|
||||
self,
|
||||
chart_name: str,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
annotation_source: str = "human",
|
||||
min_confidence: Optional[int] = None
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Get complete training data for a chart (candles + span annotations).
|
||||
|
||||
This is a convenience method that combines candle and annotation queries.
|
||||
|
||||
Args:
|
||||
chart_name: Name of the chart
|
||||
start_time: Optional start time filter
|
||||
end_time: Optional end time filter
|
||||
annotation_source: Filter annotations by source (default: 'human')
|
||||
min_confidence: Optional minimum confidence filter
|
||||
|
||||
Returns:
|
||||
Tuple of (candles_df, annotations_df)
|
||||
"""
|
||||
# Get chart
|
||||
chart = self.get_chart_by_name(chart_name)
|
||||
if not chart:
|
||||
raise ValueError(f"Chart not found: {chart_name}")
|
||||
|
||||
chart_id = chart['id']
|
||||
logger.info(f"Loading training data for chart: {chart_name} (id={chart_id})")
|
||||
|
||||
# Get candles
|
||||
candles_df = self.get_candles(chart_id, start_time, end_time)
|
||||
|
||||
# Get annotations
|
||||
annotations_df = self.get_span_annotations(
|
||||
chart_id,
|
||||
source=annotation_source,
|
||||
min_confidence=min_confidence
|
||||
)
|
||||
|
||||
return candles_df, annotations_df
|
||||
|
||||
def get_all_charts(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get all available charts.
|
||||
|
||||
Returns:
|
||||
DataFrame with all charts
|
||||
"""
|
||||
with get_db() as db:
|
||||
stmt = select(self.charts).order_by(self.charts.c.created_at)
|
||||
result = db.execute(stmt).fetchall()
|
||||
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame([dict(row._mapping) for row in result])
|
||||
logger.info(f"Found {len(df)} charts")
|
||||
|
||||
return df
|
||||
Loading…
Add table
Add a link
Reference in a new issue