- Convert numpy.int64 to Python int before passing to SQLAlchemy queries
- Prevents psycopg2.ProgrammingError: can't adapt type 'numpy.int64'
- Applied to get_candles(), get_span_annotations(), and get_point_annotations()
- All ML service database access tests now pass
117 lines
4.1 KiB
Python
117 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify ML service can access candle and annotation data from PostgreSQL.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# Put the sibling "app" directory at the front of the module search path so
# that modules living inside it are importable as top-level names.
# NOTE(review): the project imports below use the "app." package prefix, which
# resolves through the script's own directory (sys.path[0] when run directly),
# not through this entry; presumably it serves flat imports inside app/ — confirm.
sys.path.insert(0, str(Path(__file__).parent.joinpath("app")))
|
|
|
|
from app.data_access import DataAccess
|
|
from app.db import init_db
|
|
|
|
|
|
def test_db_access():
    """Exercise the ML service's read path against the PostgreSQL database.

    Initializes the DB connection and builds a ``DataAccess`` instance. It
    then lists all charts and, for the first chart, queries candles, span
    annotations, point annotations and the ``get_training_data()``
    convenience method. Each step is printed as a human-readable report.

    Returns:
        True if every step completed; False if no charts exist or any step
        raised. Empty candle/annotation results are reported as warnings but
        do not by themselves cause a False result.

    NOTE(review): because of the ``test_`` prefix, pytest may collect this
    function. pytest ignores a test's return value (recent versions warn
    about non-None returns), so a False result would not fail a pytest run.
    Run the file as a script to get a meaningful exit status.
    """
    import traceback

    print("=" * 60)
    print("Testing ML Service Database Access")
    print("=" * 60)

    try:
        # Initialize database
        print("\n1. Initializing database connection...")
        init_db()
        print(" ✓ Database connection initialized")

        # Create data access instance
        print("\n2. Creating data access instance...")
        data_access = DataAccess()
        print(" ✓ Data access instance created (tables reflected)")

        # Get all charts
        print("\n3. Querying all charts...")
        charts_df = data_access.get_all_charts()
        if charts_df.empty:
            print(" ⚠ No charts found in database")
            return False

        print(f" ✓ Found {len(charts_df)} chart(s):")
        for _, chart in charts_df.iterrows():
            print(f" - {chart['name']} (id={chart['id']})")

        # Test with the first chart. chart_id is passed on exactly as read
        # from the DataFrame (it may be a numpy integer rather than a Python
        # int), so these calls also exercise DataAccess's handling of that type.
        first_chart = charts_df.iloc[0]
        chart_name = first_chart['name']
        chart_id = first_chart['id']

        print(f"\n4. Testing data access with chart: {chart_name}")
        _report_candles(data_access, chart_id)
        _report_span_annotations(data_access, chart_id)
        _report_point_annotations(data_access, chart_id)

        if not _check_training_data(data_access, chart_name):
            return False

        print("\n" + "=" * 60)
        print("✓ All database access tests passed!")
        print("=" * 60)
        return True

    except Exception as e:
        # Top-level boundary of the smoke test: report and signal failure.
        print(f"\n✗ Error during testing: {e}")
        traceback.print_exc()
        return False


def _report_candles(data_access, chart_id):
    """Query and summarize the candles stored for *chart_id* (step 4a)."""
    print(f"\n a) Querying candles for chart_id={chart_id}...")
    candles_df = data_access.get_candles(chart_id)
    if candles_df.empty:
        print(" ⚠ No candles found")
        return
    print(f" ✓ Found {len(candles_df)} candles")
    print(f" - Time range: {candles_df['time'].min()} to {candles_df['time'].max()}")
    print(f" - Columns: {list(candles_df.columns)}")


def _report_span_annotations(data_access, chart_id):
    """Query and summarize the span annotations for *chart_id* (step 4b)."""
    print(f"\n b) Querying span annotations for chart_id={chart_id}...")
    annotations_df = data_access.get_span_annotations(chart_id)
    if annotations_df.empty:
        print(" ⚠ No span annotations found")
        return
    print(f" ✓ Found {len(annotations_df)} span annotation(s)")
    labels = annotations_df['label'].unique()
    print(f" - Unique labels: {list(labels)}")
    sources = annotations_df['source'].unique()
    print(f" - Sources: {list(sources)}")


def _report_point_annotations(data_access, chart_id):
    """Query and summarize the point annotations for *chart_id* (step 4c)."""
    print(f"\n c) Querying point annotations for chart_id={chart_id}...")
    point_annotations_df = data_access.get_point_annotations(chart_id)
    if point_annotations_df.empty:
        print(" ⚠ No point annotations found")
        return
    print(f" ✓ Found {len(point_annotations_df)} point annotation(s)")
    label_types = point_annotations_df['label_type'].unique()
    print(f" - Label types: {list(label_types)}")


def _check_training_data(data_access, chart_name):
    """Exercise ``get_training_data()`` for *chart_name* (step 5).

    Returns:
        True if the call succeeded, False if it raised. A failure is
        reported with its message and full traceback rather than propagated.
    """
    import traceback

    print("\n5. Testing get_training_data() convenience method...")
    try:
        candles_df, annotations_df = data_access.get_training_data(
            chart_name=chart_name,
            annotation_source="human",
        )
    except Exception as e:
        print(f" ✗ Error: {e}")
        # Previously only the message was printed; keep the traceback so
        # failures in this call are as debuggable as the outer handler's.
        traceback.print_exc()
        return False

    print(" ✓ Successfully retrieved training data")
    print(f" - Candles: {len(candles_df)} rows")
    print(f" - Annotations: {len(annotations_df)} rows")
    return True
|
|
|
|
|
|
if __name__ == "__main__":
    # Map the smoke test's boolean outcome onto a shell exit status:
    # 0 when every check passed, 1 otherwise.
    sys.exit(0 if test_db_access() else 1)
|