Spaces:
Running
Running
#!/usr/bin/env python3
"""
GAIA Question Loader - Web API version.

Fetch questions directly from the GAIA API instead of local files.
"""
import functools
import json
import logging
import os
import random
import re
import time
from pathlib import Path
from typing import List, Dict, Optional

import requests
from dotenv import load_dotenv
# Load environment variables (python-dotenv) before any os.getenv lookup below,
# so GAIA_API_BASE and GAIA_USERNAME can be supplied through the environment.
load_dotenv()

# Module-level logger; no handler or level is configured in this module.
logger = logging.getLogger(__name__)
def retry_with_backoff(max_retries: int = 3, initial_delay: float = 1.0, backoff_factor: float = 2.0):
    """Build a decorator that retries transient ``requests`` failures with exponential backoff.

    The wrapped call is retried when it raises ``requests.exceptions.Timeout``,
    ``requests.exceptions.ConnectionError``, or a ``requests.exceptions.HTTPError``
    whose response status is 500, 502, 503 or 504. Every other exception
    (including other HTTP statuses) propagates immediately.

    Args:
        max_retries: Total number of attempts. With a value of 0 or less the
            function is called exactly once, with no retry handling.
        initial_delay: Seconds to sleep after the first failed attempt.
        backoff_factor: Multiplier applied to the delay after each sleep.

    Returns:
        A decorator that wraps a callable with the retry loop.

    Raises:
        The exception of the final attempt once ``max_retries`` attempts failed.
    """
    retryable_statuses = (500, 502, 503, 504)

    def decorator(func):
        # Preserve the wrapped function's name and docstring for logs and introspection.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
                    reason = type(e).__name__
                    if attempt >= max_retries:
                        logger.error(f"Max retries reached for {func.__name__}")
                        raise
                except requests.exceptions.HTTPError as e:
                    # Compare with None explicitly: a requests.Response is falsy
                    # for every 4xx/5xx status, so a plain truthiness test would
                    # never let a server error through to the retry path.
                    if e.response is None or e.response.status_code not in retryable_statuses:
                        raise
                    reason = f"HTTP {e.response.status_code}"
                    if attempt >= max_retries:
                        logger.error(f"Max retries reached for {func.__name__}")
                        raise
                logger.warning(
                    f"Retry {attempt}/{max_retries} for {func.__name__} due to {reason}. Delaying {delay:.2f}s"
                )
                time.sleep(delay)
                delay *= backoff_factor
            # Only reachable when max_retries <= 0: make a single plain call.
            return func(*args, **kwargs)

        return wrapper

    return decorator
class GAIAQuestionLoaderWeb:
    """Load and manage GAIA questions from the web API.

    The full question list is fetched once in ``__init__`` and kept in
    ``self.questions``; the ``get_*``/``count_*``/``summary`` helpers only
    read that in-memory list.
    """

    def __init__(self, api_base: Optional[str] = None, username: Optional[str] = None):
        """Resolve configuration and fetch the question list.

        Args:
            api_base: Base URL of the API; falls back to the ``GAIA_API_BASE``
                environment variable, then to the built-in default.
            username: User name; falls back to ``GAIA_USERNAME``, then to the
                built-in default.
        """
        self.api_base = api_base or os.getenv("GAIA_API_BASE", "https://agents-course-unit4-scoring.hf.space")
        self.username = username or os.getenv("GAIA_USERNAME", "tonthatthienvu")
        self.questions: List[Dict] = []
        self._load_questions()

    def _make_request(self, method: str, endpoint: str, params: Optional[Dict] = None,
                      payload: Optional[Dict] = None, timeout: int = 15) -> "requests.Response":
        """Make an HTTP request with retry logic.

        Args:
            method: HTTP verb, e.g. ``"get"``.
            endpoint: Path relative to ``self.api_base``.
            params: Optional query-string parameters.
            payload: Optional JSON body.
            timeout: Per-attempt timeout in seconds.

        Returns:
            The response of the first successful attempt.

        Raises:
            requests.exceptions.RequestException: When the request ultimately fails.
        """
        # Strip slashes on both sides so a trailing "/" in api_base cannot
        # produce a doubled slash in the URL.
        url = f"{self.api_base.rstrip('/')}/{endpoint.lstrip('/')}"
        # Wrap the single-attempt sender with the module's retry helper.
        send = retry_with_backoff()(self._send_once)
        return send(method, url, params=params, payload=payload, timeout=timeout)

    def _send_once(self, method: str, url: str, params: Optional[Dict] = None,
                   payload: Optional[Dict] = None, timeout: int = 15) -> "requests.Response":
        """Perform one HTTP attempt, logging any failure before re-raising it."""
        logger.info(f"Request: {method.upper()} {url}")
        try:
            response = requests.request(method, url, params=params, json=payload, timeout=timeout)
            response.raise_for_status()
            return response
        except requests.exceptions.HTTPError as e:
            # A Response is falsy for 4xx/5xx, so test against None explicitly;
            # otherwise the response body would never be logged.
            if e.response is not None:
                logger.error(f"HTTPError: {e.response.status_code} for {method.upper()} {url}")
                logger.error(f"Response: {e.response.text[:200]}")
            else:
                logger.error(f"HTTPError: {e} for {method.upper()} {url}")
            raise
        except requests.exceptions.Timeout:
            logger.error(f"Timeout: Request to {url} timed out after {timeout}s")
            raise
        except requests.exceptions.ConnectionError as e:
            logger.error(f"ConnectionError: Could not connect to {url}. Details: {e}")
            raise

    def _load_questions(self) -> None:
        """Fetch all questions from the GAIA API into ``self.questions``.

        Never raises for request or parse failures: the problem is logged and
        printed, and ``self.questions`` is left as an empty list.
        """
        self.questions = []
        try:
            logger.info(f"Fetching questions from GAIA API: {self.api_base}/questions")
            response = self._make_request("get", "questions", timeout=15)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to fetch questions from API: {e}")
            print(f"❌ Failed to load questions from web API: {e}")
            return

        try:
            data = response.json()
        except ValueError as e:
            # Every JSON decode error raised by response.json() is a ValueError,
            # whichever requests version / JSON backend is installed.
            logger.error(f"Failed to parse JSON response: {e}")
            print(f"❌ Failed to parse questions from web API: {e}")
            return

        if not isinstance(data, list):
            # The query helpers iterate over a list of dicts; reject anything else.
            logger.error(f"Failed to parse JSON response: expected a list, got {type(data).__name__}")
            print(f"❌ Failed to parse questions from web API: expected a list, got {type(data).__name__}")
            return

        self.questions = data
        print(f"✅ Loaded {len(self.questions)} GAIA questions from web API")
        logger.info(f"Successfully retrieved {len(self.questions)} questions from API")

    def _pick_cached_question(self) -> Optional[Dict]:
        """Return a random already-loaded question, or ``None`` when none are loaded."""
        return random.choice(self.questions) if self.questions else None

    def get_random_question(self) -> Optional[Dict]:
        """Get a random question from the API.

        Falls back to a random pick from the already-loaded questions when the
        request fails or the response is not a JSON object.

        Returns:
            A question dict, or ``None`` when the API call fails and no
            questions are loaded.
        """
        try:
            logger.info(f"Getting random question from: {self.api_base}/random-question")
            response = self._make_request("get", "random-question", timeout=15)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get random question: {e}")
            return self._pick_cached_question()

        try:
            question = response.json()
        except ValueError as e:
            logger.error(f"Failed to parse random question response: {e}")
            return self._pick_cached_question()

        if not isinstance(question, dict):
            logger.error(f"Failed to parse random question response: expected an object, got {type(question).__name__}")
            return self._pick_cached_question()

        task_id = question.get('task_id', 'Unknown')
        logger.info(f"Successfully retrieved random question: {task_id}")
        return question

    def get_question_by_id(self, task_id: str) -> Optional[Dict]:
        """Get the first loaded question whose ``task_id`` matches, or ``None``."""
        return next((q for q in self.questions if q.get('task_id') == task_id), None)

    def get_questions_by_level(self, level: str) -> List[Dict]:
        """Get all loaded questions whose ``Level`` field equals ``level``."""
        return [q for q in self.questions if q.get('Level') == level]

    def get_questions_with_files(self) -> List[Dict]:
        """Get all loaded questions with a non-empty ``file_name``."""
        return [q for q in self.questions if q.get('file_name')]

    def get_questions_without_files(self) -> List[Dict]:
        """Get all loaded questions with a missing or empty ``file_name``."""
        return [q for q in self.questions if not q.get('file_name')]

    def count_by_level(self) -> Dict[str, int]:
        """Count loaded questions per ``Level`` value (``'Unknown'`` when absent)."""
        levels: Dict[str, int] = {}
        for q in self.questions:
            level = q.get('Level', 'Unknown')
            levels[level] = levels.get(level, 0) + 1
        return levels

    def summary(self) -> Dict:
        """Get a summary of the loaded questions and the active configuration."""
        return {
            'total_questions': len(self.questions),
            'with_files': len(self.get_questions_with_files()),
            'without_files': len(self.get_questions_without_files()),
            'by_level': self.count_by_level(),
            'api_base': self.api_base,
            'username': self.username,
        }

    @staticmethod
    def _safe_filename(content_disposition: Optional[str], fallback: str) -> str:
        """Derive a bare file name from a ``Content-Disposition`` header value.

        The header comes from the remote server, so only the final path
        component is kept; directory parts, absolute paths and ``..`` are
        discarded. ``fallback`` is used when the header carries no usable name.
        """
        candidate = fallback
        if content_disposition:
            # Quoted form first; the unquoted form stops at the next ";" so
            # trailing header parameters are not swallowed into the name.
            match = re.search(r'filename=(?:"([^"]+)"|([^";]+))', content_disposition)
            if match:
                candidate = (match.group(1) or match.group(2)).strip()

        unusable = ("", ".", "..")
        # Treat backslashes as separators too, then keep the last component only.
        name = Path(str(candidate).replace("\\", "/")).name
        if name in unusable:
            name = Path(str(fallback).replace("\\", "/")).name
        return "download" if name in unusable else name

    def download_file(self, task_id: str, save_dir: str = "./downloads") -> Optional[str]:
        """Download the file associated with a question.

        Args:
            task_id: Task whose file is requested from ``files/<task_id>``.
            save_dir: Directory to save into; created (with parents) if missing.

        Returns:
            The saved file's path as a string, or ``None`` on any failure.
        """
        try:
            # Create download directory
            target_dir = Path(save_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            logger.info(f"Downloading file for task: {task_id}")
            response = self._make_request("get", f"files/{task_id}", timeout=30)

            # Prefer the server-provided name, reduced to a bare file name so
            # the write cannot land outside save_dir.
            filename = self._safe_filename(response.headers.get('content-disposition'), task_id)

            # Save file
            file_path = target_dir / filename
            with open(file_path, 'wb') as f:
                f.write(response.content)

            logger.info(f"File downloaded successfully: {file_path}")
            return str(file_path)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to download file for task {task_id}: {e}")
            return None
        except Exception as e:
            # Best-effort API: filesystem and other errors are reported as None.
            logger.error(f"Error saving file for task {task_id}: {e}")
            return None

    def test_api_connection(self) -> bool:
        """Test connectivity to the GAIA API by requesting the question list.

        Returns:
            ``True`` when the request succeeds, ``False`` on any exception.
        """
        try:
            logger.info(f"Testing API connection to: {self.api_base}")
            self._make_request("get", "questions", timeout=10)
            logger.info("✅ API connection successful")
            return True
        except Exception as e:
            logger.error(f"❌ API connection failed: {e}")
            return False