fix: resolve numpy type conversion issues in ML service data access
- Convert numpy.int64 to Python int before passing to SQLAlchemy queries
- Prevents psycopg2.ProgrammingError: can't adapt type 'numpy.int64'
- Applied to get_candles(), get_span_annotations(), and get_point_annotations()
- All ML service database access tests now passing successfully
This commit is contained in:
parent
5377431c9d
commit
d1557a3846
6 changed files with 437 additions and 119 deletions
|
|
@ -78,6 +78,9 @@ class DataAccess:
|
|||
DataFrame with columns: id, chart_id, time, open, high, low, close
|
||||
"""
|
||||
with get_db() as db:
|
||||
# Convert numpy types to native Python types for SQLAlchemy
|
||||
chart_id = int(chart_id)
|
||||
|
||||
# Build query
|
||||
stmt = select(self.candles).where(self.candles.c.chart_id == chart_id)
|
||||
|
||||
|
|
@ -118,6 +121,11 @@ class DataAccess:
|
|||
DataFrame with span annotations
|
||||
"""
|
||||
with get_db() as db:
|
||||
# Convert numpy types to native Python types for SQLAlchemy
|
||||
chart_id = int(chart_id)
|
||||
if min_confidence is not None:
|
||||
min_confidence = int(min_confidence)
|
||||
|
||||
# Build query
|
||||
stmt = select(self.span_annotations).where(
|
||||
self.span_annotations.c.chart_id == chart_id
|
||||
|
|
@ -164,6 +172,9 @@ class DataAccess:
|
|||
DataFrame with point annotations
|
||||
"""
|
||||
with get_db() as db:
|
||||
# Convert numpy types to native Python types for SQLAlchemy
|
||||
chart_id = int(chart_id)
|
||||
|
||||
# Build query
|
||||
stmt = select(self.annotations).where(
|
||||
self.annotations.c.chart_id == chart_id
|
||||
|
|
|
|||
117
services/ml/test_db_access.py
Normal file
117
services/ml/test_db_access.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify ML service can access candle and annotation data from PostgreSQL.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add app directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent / "app"))
|
||||
|
||||
from app.data_access import DataAccess
|
||||
from app.db import init_db
|
||||
|
||||
|
||||
def test_db_access():
    """Smoke-test the ML service's read path against the database.

    Runs five printed steps in order: initialise the database, build a
    ``DataAccess`` instance, list all charts, query candles / span
    annotations / point annotations for the first chart, and finally call
    ``get_training_data()`` for that chart with ``annotation_source="human"``.

    Returns:
        ``True`` when every step completes.  ``False`` when no charts exist,
        when ``get_training_data()`` raises, or when any other step raises
        (in which case the traceback is printed as well).  Empty candle or
        annotation results only produce a warning line; they do not fail
        the run.
    """
    banner = "=" * 60
    print(banner)
    print("Testing ML Service Database Access")
    print(banner)

    try:
        # Step 1: database initialisation.
        print("\n1. Initializing database connection...")
        init_db()
        print(" ✓ Database connection initialized")

        # Step 2: data access object.
        print("\n2. Creating data access instance...")
        accessor = DataAccess()
        print(" ✓ Data access instance created (tables reflected)")

        # Step 3: chart listing; nothing further can be checked without one.
        print("\n3. Querying all charts...")
        all_charts = accessor.get_all_charts()
        if all_charts.empty:
            print(" ⚠ No charts found in database")
            return False

        print(f" ✓ Found {len(all_charts)} chart(s):")
        for _, row in all_charts.iterrows():
            print(f" - {row['name']} (id={row['id']})")

        # Step 4: per-chart queries, using the first chart returned.
        # The id is handed over exactly as it comes out of the DataFrame row.
        selected = all_charts.iloc[0]
        chart_name = selected['name']
        chart_id = selected['id']
        print(f"\n4. Testing data access with chart: {chart_name}")

        # 4a: candles.
        print(f"\n a) Querying candles for chart_id={chart_id}...")
        candles = accessor.get_candles(chart_id)
        if not candles.empty:
            print(f" ✓ Found {len(candles)} candles")
            earliest = candles['time'].min()
            latest = candles['time'].max()
            print(f" - Time range: {earliest} to {latest}")
            print(f" - Columns: {list(candles.columns)}")
        else:
            print(" ⚠ No candles found")

        # 4b: span annotations.
        print(f"\n b) Querying span annotations for chart_id={chart_id}...")
        spans = accessor.get_span_annotations(chart_id)
        if not spans.empty:
            print(f" ✓ Found {len(spans)} span annotation(s)")
            distinct_labels = spans['label'].unique()
            print(f" - Unique labels: {list(distinct_labels)}")
            distinct_sources = spans['source'].unique()
            print(f" - Sources: {list(distinct_sources)}")
        else:
            print(" ⚠ No span annotations found")

        # 4c: point annotations.
        print(f"\n c) Querying point annotations for chart_id={chart_id}...")
        points = accessor.get_point_annotations(chart_id)
        if not points.empty:
            print(f" ✓ Found {len(points)} point annotation(s)")
            distinct_types = points['label_type'].unique()
            print(f" - Label types: {list(distinct_types)}")
        else:
            print(" ⚠ No point annotations found")

        # Step 5: combined helper; a failure here is reported without a traceback.
        print("\n5. Testing get_training_data() convenience method...")
        try:
            training_candles, training_annotations = accessor.get_training_data(
                chart_name=chart_name,
                annotation_source="human",
            )
            print(" ✓ Successfully retrieved training data")
            print(f" - Candles: {len(training_candles)} rows")
            print(f" - Annotations: {len(training_annotations)} rows")
        except Exception as step_error:
            print(f" ✗ Error: {step_error}")
            return False

        print(f"\n{banner}")
        print("✓ All database access tests passed!")
        print(banner)
        return True

    except Exception as failure:
        # Any step above that raises lands here: report, dump the traceback, fail.
        print(f"\n✗ Error during testing: {failure}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Translate the boolean outcome into a process exit status: 0 on success, 1 on failure.
    exit_status = 0 if test_db_access() else 1
    sys.exit(exit_status)
|
||||
Loading…
Add table
Add a link
Reference in a new issue