Add debugging for HF datasets queries
Browse files- climateqa/engine/talk_to_data/query.py +83 -7
- requirements.txt +1 -0
climateqa/engine/talk_to_data/query.py
CHANGED
@@ -3,6 +3,8 @@ from concurrent.futures import ThreadPoolExecutor
|
|
3 |
import duckdb
|
4 |
import pandas as pd
|
5 |
import os
|
|
|
|
|
6 |
|
7 |
def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
|
8 |
"""Retrieves the name of the indicator column within a table.
|
@@ -41,14 +43,88 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
|
41 |
def _execute_query():
|
42 |
# Execute the query
|
43 |
con = duckdb.connect()
|
|
|
|
|
44 |
HF_TTD_TOKEN = os.getenv("HF_TTD_TOKEN")
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# Run the query in a thread pool to avoid blocking
|
54 |
loop = asyncio.get_event_loop()
|
|
|
3 |
import duckdb
|
4 |
import pandas as pd
|
5 |
import os
|
6 |
+
import requests
|
7 |
+
import tempfile
|
8 |
|
9 |
def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
|
10 |
"""Retrieves the name of the indicator column within a table.
|
|
|
43 |
def _execute_query():
    """Run ``sql_query`` with DuckDB, with fallbacks for Hugging Face HTTP errors.

    ``sql_query`` is read from the enclosing ``execute_sql_query`` scope.

    Strategy:
      1. If the ``HF_TTD_TOKEN`` environment variable is set, register it as a
         DuckDB ``HUGGINGFACE`` secret, then run the query.
      2. On ``duckdb.HTTPException`` with a token configured, retry once on a
         fresh connection that has no secret.
      3. If that also fails (or no token is set), download the remote file to
         a private temporary file and re-run the query against the local copy.

    Returns:
        The query result as a pandas DataFrame (``fetchdf()``).

    Raises:
        duckdb.HTTPException: the original HTTP error, when no URL can be
            determined or the download returns a status other than 200/401.
        Exception: when the download is rejected with HTTP 401.
        Any other exception raised by DuckDB or ``requests`` is logged (for the
        initial query) and propagated unchanged.
    """
    import re
    from urllib.parse import urlparse

    # (connect, read) timeouts in seconds so a stalled download cannot hang
    # the worker thread forever.
    download_timeout = (10, 60)

    hf_token = os.getenv("HF_TTD_TOKEN")
    con = duckdb.connect()
    try:
        try:
            if hf_token:
                # Double any single quote so the token cannot terminate the
                # SQL string literal early.
                escaped_token = hf_token.replace("'", "''")
                con.execute(f"""
                    CREATE SECRET IF NOT EXISTS hf_token (
                        TYPE HUGGINGFACE,
                        TOKEN '{escaped_token}'
                    );
                """)
                print("Hugging Face authentication configured")

            # Execute the query
            return con.execute(sql_query).fetchdf()

        except duckdb.HTTPException as e:
            print(f"HTTP error accessing Hugging Face dataset: {e}")

            # If we have a token but still get an HTTP error, try without
            # authentication on a fresh connection that has no secret.
            if hf_token:
                print("Retrying without authentication...")
                try:
                    con_no_auth = duckdb.connect()
                    try:
                        return con_no_auth.execute(sql_query).fetchdf()
                    finally:
                        con_no_auth.close()
                except Exception as e2:
                    print(f"Also failed without authentication: {e2}")

            # Last resort: download the file locally and retry.
            print("Trying to download file locally and retry...")

            # Extract the URL from the error message, else from the query.
            error_str = str(e)
            url = None
            if "HTTP GET error on '" in error_str:
                url = error_str.split("HTTP GET error on '")[1].split("'")[0]
            else:
                url_match = re.search(r"'(https://huggingface\.co/[^']+)'", sql_query)
                if url_match:
                    url = url_match.group(1)

            if not url:
                print("Could not extract URL from error message")
                raise

            parsed_url = urlparse(url)
            # Keep the remote file name (and so its extension) as the suffix
            # so DuckDB can still infer the file format from the local path.
            file_name = os.path.basename(parsed_url.path)
            # mkstemp gives every call its own file: concurrent queries for
            # the same remote file no longer write to one shared path.
            fd, local_path = tempfile.mkstemp(suffix=f"_{file_name}")
            os.close(fd)
            try:
                print(f"Downloading {url} to {local_path}")

                # Only send the token to Hugging Face over HTTPS; the URL may
                # come from an error message or an arbitrary query.
                host = (parsed_url.hostname or "").lower()
                is_hf_host = host == "huggingface.co" or host.endswith(".huggingface.co")
                headers = {}
                if hf_token and parsed_url.scheme == "https" and is_hf_host:
                    headers['Authorization'] = f'Bearer {hf_token}'

                with requests.get(
                    url, headers=headers, stream=True, timeout=download_timeout
                ) as response:
                    if response.status_code == 200:
                        with open(local_path, 'wb') as f:
                            for chunk in response.iter_content(chunk_size=8192):
                                f.write(chunk)
                    elif response.status_code == 401:
                        print("Authentication failed - check your HF_TTD_TOKEN")
                        raise Exception(
                            "Authentication failed. Please check your HF_TTD_TOKEN environment variable."
                        ) from e
                    else:
                        print(f"Failed to download file: {response.status_code}")
                        raise

                # Point the query at the local copy instead of the remote URL.
                quoted_local_path = local_path.replace("'", "''")
                modified_sql = sql_query.replace(f"'{url}'", f"'{quoted_local_path}'")
                return con.execute(modified_sql).fetchdf()
            finally:
                # fetchdf() has materialized the result, so the temporary copy
                # is no longer needed. Cleanup is best-effort.
                try:
                    os.remove(local_path)
                except OSError:
                    pass

        except Exception as e:
            print(f"Unexpected error: {e}")
            raise
    finally:
        # Always release the DuckDB connection, whichever path was taken.
        con.close()
|
128 |
|
129 |
# Run the query in a thread pool to avoid blocking
|
130 |
loop = asyncio.get_event_loop()
|
requirements.txt
CHANGED
@@ -27,3 +27,4 @@ openai==1.61.1
|
|
27 |
pydantic==2.9.2
|
28 |
pydantic-settings==2.2.1
|
29 |
geojson==3.2.0
|
|
|
|
27 |
pydantic==2.9.2
|
28 |
pydantic-settings==2.2.1
|
29 |
geojson==3.2.0
|
30 |
+
requests==2.32.3
|