- Created scripts/migrate-sqlite-to-postgres.py as an alternative to the TypeScript version
- Handles all type conversions: timestamps, booleans, and JSONB fields
- Successfully migrated all 2,836 rows from SQLite to PostgreSQL
- Verified data integrity: all 6 tables migrated correctly
- Charts: 1, Candles: 2,592, Annotations: 4, Span annotations: 223
249 lines
8.2 KiB
Python
249 lines
8.2 KiB
Python
"""
|
|
Direct database access for reading candle and annotation data.
|
|
|
|
This module provides read-only access to frontend tables (candles, annotations,
|
|
span_annotations, charts) using SQLAlchemy table reflections or raw queries.
|
|
"""
|
|
|
|
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime
|
|
|
|
import pandas as pd
|
|
from sqlalchemy import Table, MetaData, select, and_
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.db import engine, get_db
|
|
|
|
# Module-scoped logger used by DataAccess to report reflection and query results.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class DataAccess:
    """
    Provides read-only access to frontend database tables.

    Uses SQLAlchemy table reflection to access tables managed by Drizzle ORM
    without creating duplicate model definitions. Every query method issues
    SELECT statements only and returns plain dicts or pandas DataFrames.

    DataFrame-returning methods always return frames whose columns mirror the
    queried table, even when no rows match. ``df.empty`` is therefore the
    reliable "nothing found" check, and column access never raises ``KeyError``.

    NOTE(review): sessions are opened with ``with get_db() as db``. This
    assumes ``app.db.get_db`` returns a context manager (for example a
    ``@contextmanager``-decorated generator). A bare FastAPI-style generator
    dependency does not support ``with``; confirm against ``app.db``.
    """

    def __init__(self):
        """
        Reflect the frontend tables from the database bound to ``engine``.

        Raises:
            Exception: Any error SQLAlchemy raises while reflecting a table
                (for example when a table is missing). It is logged together
                with its traceback and then re-raised unchanged.
        """
        self.metadata = MetaData()

        # These tables are managed by Drizzle ORM on the frontend, so their
        # schema is read from the live database instead of being redeclared.
        try:
            self.charts = Table('charts', self.metadata, autoload_with=engine)
            self.candles = Table('candles', self.metadata, autoload_with=engine)
            self.annotations = Table('annotations', self.metadata, autoload_with=engine)
            self.span_annotations = Table('span_annotations', self.metadata, autoload_with=engine)
            self.span_label_types = Table('span_label_types', self.metadata, autoload_with=engine)
            logger.info("Successfully reflected frontend tables")
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error reflecting frontend tables: %s", e)
            raise

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _fetch_frame(self, stmt) -> pd.DataFrame:
        """
        Execute a SELECT statement and return its rows as a DataFrame.

        Args:
            stmt: SQLAlchemy ``Select`` to execute.

        Returns:
            DataFrame with one row per result row. Its columns come from the
            result metadata, so they are present even when no rows match.
        """
        with get_db() as db:
            result = db.execute(stmt)
            # Column names come from the result metadata, not from the rows,
            # so an empty result still yields a frame with the full schema.
            columns = list(result.keys())
            rows = [dict(row._mapping) for row in result.fetchall()]
        return pd.DataFrame(rows, columns=columns)

    @staticmethod
    def _log_chart_load(df: pd.DataFrame, what: str, chart_id: int) -> None:
        """Log how many ``what`` rows were loaded for ``chart_id``; warn if none."""
        if df.empty:
            logger.warning("No %s found for chart_id=%s", what, chart_id)
        else:
            logger.info("Loaded %d %s for chart_id=%s", len(df), what, chart_id)

    # ------------------------------------------------------------------
    # Public queries
    # ------------------------------------------------------------------

    def get_chart_by_name(self, chart_name: str) -> Optional[Dict[str, Any]]:
        """
        Get chart by name.

        Args:
            chart_name: Name of the chart

        Returns:
            Chart row as a dict (column name -> value), or None if no chart
            has that name. If several charts share the name, the first row
            returned by the database is used.
        """
        stmt = select(self.charts).where(self.charts.c.name == chart_name)
        with get_db() as db:
            row = db.execute(stmt).fetchone()
            return dict(row._mapping) if row is not None else None

    def get_candles(
        self,
        chart_id: int,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None
    ) -> pd.DataFrame:
        """
        Get candle data for a chart, ordered by ``time``.

        Args:
            chart_id: Chart ID
            start_time: Optional inclusive lower bound on ``time``
            end_time: Optional inclusive upper bound on ``time``

        Returns:
            DataFrame whose columns mirror the reflected ``candles`` table
            (expected: id, chart_id, time, open, high, low, close). When no
            candles match, the frame is empty but still has those columns.
        """
        candles = self.candles
        stmt = select(candles).where(candles.c.chart_id == chart_id)

        if start_time is not None:
            stmt = stmt.where(candles.c.time >= start_time)
        if end_time is not None:
            stmt = stmt.where(candles.c.time <= end_time)

        df = self._fetch_frame(stmt.order_by(candles.c.time))
        self._log_chart_load(df, "candles", chart_id)
        return df

    def get_span_annotations(
        self,
        chart_id: int,
        source: Optional[str] = None,
        min_confidence: Optional[int] = None
    ) -> pd.DataFrame:
        """
        Get span annotations for a chart, ordered by ``start_time``.

        Args:
            chart_id: Chart ID
            source: Optional filter by source ('human', 'model', 'hybrid').
                A falsy value (None or '') disables the filter.
            min_confidence: Optional minimum confidence (inclusive). When set,
                rows with a NULL confidence are excluded.

        Returns:
            DataFrame whose columns mirror the reflected ``span_annotations``
            table; empty (with those columns) when nothing matches.
        """
        spans = self.span_annotations
        stmt = select(spans).where(spans.c.chart_id == chart_id)

        if source:
            stmt = stmt.where(spans.c.source == source)

        if min_confidence is not None:
            stmt = stmt.where(
                and_(
                    spans.c.confidence.isnot(None),
                    spans.c.confidence >= min_confidence,
                )
            )

        df = self._fetch_frame(stmt.order_by(spans.c.start_time))
        self._log_chart_load(df, "span annotations", chart_id)
        return df

    def get_point_annotations(
        self,
        chart_id: int,
        label_type: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Get point annotations for a chart, ordered by ``timestamp``.

        Args:
            chart_id: Chart ID
            label_type: Optional filter by label type. A falsy value
                (None or '') disables the filter.

        Returns:
            DataFrame whose columns mirror the reflected ``annotations``
            table; empty (with those columns) when nothing matches.
        """
        points = self.annotations
        stmt = select(points).where(points.c.chart_id == chart_id)

        if label_type:
            stmt = stmt.where(points.c.label_type == label_type)

        df = self._fetch_frame(stmt.order_by(points.c.timestamp))
        self._log_chart_load(df, "point annotations", chart_id)
        return df

    def get_training_data(
        self,
        chart_name: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        annotation_source: Optional[str] = "human",
        min_confidence: Optional[int] = None
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Get complete training data for a chart (candles + span annotations).

        This is a convenience method that combines candle and annotation queries.

        Args:
            chart_name: Name of the chart
            start_time: Optional inclusive start time filter (candles only)
            end_time: Optional inclusive end time filter (candles only)
            annotation_source: Filter annotations by source (default: 'human').
                Pass None to load annotations from every source.
            min_confidence: Optional minimum confidence filter

        Returns:
            Tuple of (candles_df, annotations_df)

        Raises:
            ValueError: If no chart named ``chart_name`` exists.
        """
        chart = self.get_chart_by_name(chart_name)
        if not chart:
            raise ValueError(f"Chart not found: {chart_name}")

        chart_id = chart['id']
        logger.info("Loading training data for chart: %s (id=%s)", chart_name, chart_id)

        candles_df = self.get_candles(chart_id, start_time, end_time)

        # NOTE(review): start_time/end_time are not applied to the annotations,
        # so spans for the whole chart are returned. Confirm this is intended.
        annotations_df = self.get_span_annotations(
            chart_id,
            source=annotation_source,
            min_confidence=min_confidence
        )

        return candles_df, annotations_df

    def get_all_charts(self) -> pd.DataFrame:
        """
        Get all available charts, ordered by ``created_at``.

        Returns:
            DataFrame whose columns mirror the reflected ``charts`` table;
            empty (with those columns) when there are no charts.
        """
        df = self._fetch_frame(select(self.charts).order_by(self.charts.c.created_at))
        if not df.empty:
            logger.info("Found %d charts", len(df))
        return df
|