candle-annotator/services/ml/app/data_access.py
Marko Djordjevic bfe437857b feat: add Python migration script and successfully test SQLite to PostgreSQL data migration
- Created scripts/migrate-sqlite-to-postgres.py as alternative to TypeScript version
- Handles all type conversions: timestamps, booleans, and JSONB fields
- Successfully migrated all 2,836 rows from SQLite to PostgreSQL
- Verified data integrity: all 6 tables migrated correctly
- Charts: 1, Candles: 2,592, Annotations: 4, Span annotations: 223
2026-02-17 14:01:21 +01:00

249 lines
8.2 KiB
Python

"""
Direct database access for reading candle and annotation data.
This module provides read-only access to frontend tables (candles, annotations,
span_annotations, charts) using SQLAlchemy table reflections or raw queries.
"""
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime
import pandas as pd
from sqlalchemy import Table, MetaData, select, and_
from sqlalchemy.orm import Session
from app.db import engine, get_db
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
class DataAccess:
    """
    Provides read-only access to frontend database tables.

    Uses SQLAlchemy table reflection to access tables managed by Drizzle ORM
    without creating duplicate model definitions. Every query method builds a
    ``select()`` against the reflected tables, runs it inside a ``get_db()``
    context, and converts the rows to a dict or a pandas DataFrame.
    """

    def __init__(self):
        """
        Reflect the frontend tables from the database behind ``engine``.

        The reflected ``Table`` objects are stored as ``charts``, ``candles``,
        ``annotations``, ``span_annotations`` and ``span_label_types``, all
        attached to ``self.metadata``. (``span_label_types`` is reflected but
        not queried by any method of this class.)

        Raises:
            Exception: any error raised while reflecting is logged (with its
                traceback) and re-raised unchanged.
        """
        self.metadata = MetaData()
        # Reflect frontend tables; their schema is managed by the frontend.
        try:
            self.charts = Table('charts', self.metadata, autoload_with=engine)
            self.candles = Table('candles', self.metadata, autoload_with=engine)
            self.annotations = Table('annotations', self.metadata, autoload_with=engine)
            self.span_annotations = Table('span_annotations', self.metadata, autoload_with=engine)
            self.span_label_types = Table('span_label_types', self.metadata, autoload_with=engine)
        except Exception as e:
            # logger.exception keeps the traceback in the log; re-raise so the
            # caller still receives the original error.
            logger.exception("Error reflecting frontend tables: %s", e)
            raise
        else:
            logger.info("Successfully reflected frontend tables")

    @staticmethod
    def _query_frame(db: "Session", stmt: Any) -> pd.DataFrame:
        """
        Execute ``stmt`` on ``db`` and return all rows as a DataFrame.

        Column labels are taken from the result's keys, so an empty result
        still yields a DataFrame carrying the selected columns (instead of a
        column-less ``pd.DataFrame()``). Downstream ``df["col"]`` access
        therefore works even when nothing matched.
        """
        result = db.execute(stmt)
        columns = list(result.keys())
        records = [dict(row._mapping) for row in result.fetchall()]
        return pd.DataFrame(records, columns=columns)

    def get_chart_by_name(self, chart_name: str) -> Optional[Dict[str, Any]]:
        """
        Get chart by name.

        Args:
            chart_name: Name of the chart (exact match).

        Returns:
            The chart row as a column-name -> value dict, or None if no chart
            has that name. If several charts share the name, the first row the
            database returns is used (no ordering is applied).
        """
        with get_db() as db:
            stmt = select(self.charts).where(self.charts.c.name == chart_name)
            # first() fetches one row (or None) and closes the result, so any
            # further matching rows are discarded rather than left pending.
            row = db.execute(stmt).first()
            if row is None:
                return None
            return dict(row._mapping)

    def get_candles(
        self,
        chart_id: int,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None
    ) -> pd.DataFrame:
        """
        Get candle data for a chart, ordered by ``time`` ascending.

        Args:
            chart_id: Chart ID.
            start_time: If given, keep only candles with ``time >= start_time``.
            end_time: If given, keep only candles with ``time <= end_time``
                (inclusive).

        Returns:
            DataFrame with the candles table's columns (expected: id,
            chart_id, time, open, high, low, close). When no candles match,
            the DataFrame is empty but still has those columns.
        """
        with get_db() as db:
            stmt = select(self.candles).where(self.candles.c.chart_id == chart_id)
            # Compare against None so only an omitted bound disables the filter.
            if start_time is not None:
                stmt = stmt.where(self.candles.c.time >= start_time)
            if end_time is not None:
                stmt = stmt.where(self.candles.c.time <= end_time)
            stmt = stmt.order_by(self.candles.c.time)
            df = self._query_frame(db, stmt)
        if df.empty:
            logger.warning("No candles found for chart_id=%s", chart_id)
            return df
        logger.info("Loaded %d candles for chart_id=%s", len(df), chart_id)
        return df

    def get_span_annotations(
        self,
        chart_id: int,
        source: Optional[str] = None,
        min_confidence: Optional[int] = None
    ) -> pd.DataFrame:
        """
        Get span annotations for a chart, ordered by ``start_time`` ascending.

        Args:
            chart_id: Chart ID.
            source: Optional filter by source ('human', 'model', 'hybrid').
                A falsy value (None or "") applies no source filter.
            min_confidence: Optional minimum confidence. When given, spans
                with a NULL confidence are excluded.

        Returns:
            DataFrame with the span_annotations table's columns; empty (but
            with those columns) when nothing matches.
        """
        with get_db() as db:
            stmt = select(self.span_annotations).where(
                self.span_annotations.c.chart_id == chart_id
            )
            if source:
                stmt = stmt.where(self.span_annotations.c.source == source)
            if min_confidence is not None:
                # The explicit IS NOT NULL documents that unscored spans are
                # dropped (in SQL, NULL >= x would not match either).
                stmt = stmt.where(
                    and_(
                        self.span_annotations.c.confidence.isnot(None),
                        self.span_annotations.c.confidence >= min_confidence
                    )
                )
            stmt = stmt.order_by(self.span_annotations.c.start_time)
            df = self._query_frame(db, stmt)
        if df.empty:
            logger.warning("No span annotations found for chart_id=%s", chart_id)
            return df
        logger.info("Loaded %d span annotations for chart_id=%s", len(df), chart_id)
        return df

    def get_point_annotations(
        self,
        chart_id: int,
        label_type: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Get point annotations for a chart, ordered by ``timestamp`` ascending.

        Args:
            chart_id: Chart ID.
            label_type: Optional filter by label type. A falsy value (None or
                "") applies no label filter.

        Returns:
            DataFrame with the annotations table's columns; empty (but with
            those columns) when nothing matches.
        """
        with get_db() as db:
            stmt = select(self.annotations).where(
                self.annotations.c.chart_id == chart_id
            )
            if label_type:
                stmt = stmt.where(self.annotations.c.label_type == label_type)
            stmt = stmt.order_by(self.annotations.c.timestamp)
            df = self._query_frame(db, stmt)
        if df.empty:
            logger.warning("No point annotations found for chart_id=%s", chart_id)
            return df
        logger.info("Loaded %d point annotations for chart_id=%s", len(df), chart_id)
        return df

    def get_training_data(
        self,
        chart_name: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        annotation_source: str = "human",
        min_confidence: Optional[int] = None
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Get complete training data for a chart (candles + span annotations).

        Convenience wrapper that resolves the chart by name, then combines
        ``get_candles`` and ``get_span_annotations``. Note that the time
        window is applied to the candles only; annotations are not
        time-filtered.

        Args:
            chart_name: Name of the chart.
            start_time: Optional start time filter for candles.
            end_time: Optional end time filter for candles.
            annotation_source: Filter annotations by source (default: 'human').
            min_confidence: Optional minimum confidence filter for annotations.

        Returns:
            Tuple of (candles_df, annotations_df).

        Raises:
            ValueError: If no chart named ``chart_name`` exists.
        """
        chart = self.get_chart_by_name(chart_name)
        if chart is None:
            raise ValueError(f"Chart not found: {chart_name}")
        chart_id = chart['id']
        logger.info("Loading training data for chart: %s (id=%s)", chart_name, chart_id)
        candles_df = self.get_candles(chart_id, start_time=start_time, end_time=end_time)
        annotations_df = self.get_span_annotations(
            chart_id,
            source=annotation_source,
            min_confidence=min_confidence
        )
        return candles_df, annotations_df

    def get_all_charts(self) -> pd.DataFrame:
        """
        Get all available charts, ordered by ``created_at`` ascending.

        Returns:
            DataFrame with the charts table's columns; empty (but with those
            columns) when there are no charts.
        """
        with get_db() as db:
            stmt = select(self.charts).order_by(self.charts.c.created_at)
            df = self._query_frame(db, stmt)
        if df.empty:
            return df
        logger.info("Found %d charts", len(df))
        return df