|
import asyncio
import functools
import logging
import os
import random

from datasets import load_dataset
from fastapi import APIRouter, HTTPException
from huggingface_hub import HfApi, dataset_info
|
|
|
# Router holding the benchmark endpoints defined below; the "benchmark" tag
# groups them together in FastAPI's generated OpenAPI docs.
router = APIRouter(tags=["benchmark"])
|
|
|
_logger = logging.getLogger(__name__)

# Maximum number of example questions returned by the endpoint.
_MAX_SAMPLE_QUESTIONS = 2


def _sample_questions(dataset_repo_id: str, config_name: str, question_type: str, count: int) -> list:
    """Return up to ``count`` random questions from one config's ``train`` split.

    Blocking: ``load_dataset`` fetches the dataset, so call this off the event
    loop. Exceptions raised while loading or indexing propagate to the caller.
    """
    if count <= 0:
        return []
    dataset = load_dataset(dataset_repo_id, config_name)
    if not dataset:
        return []
    train = dataset["train"]
    n_rows = len(train)
    # random.sample(range(0), 0) == [], so an empty split yields no questions.
    picked = random.sample(range(n_rows), min(count, n_rows))
    return [
        {
            "id": str(idx),
            "question": train[idx].get("question", ""),
            "type": question_type,
        }
        for idx in picked
    ]


async def _sample_questions_best_effort(
    dataset_repo_id: str, config_name: str, question_type: str, count: int
) -> list:
    """Run :func:`_sample_questions` in a worker thread; log and return [] on failure."""
    label = question_type.replace("_", "-")  # "single_shot" -> "single-shot" for logs
    loop = asyncio.get_running_loop()
    try:
        # load_dataset does blocking network/disk I/O; running it inline in an
        # async route would stall every other request on the event loop.
        questions = await loop.run_in_executor(
            None,
            functools.partial(
                _sample_questions, dataset_repo_id, config_name, question_type, count
            ),
        )
    except Exception as exc:  # best-effort: a missing config must not fail the request
        _logger.warning("Error loading %s questions: %s", label, exc)
        return []
    _logger.info("Loaded %d %s questions", len(questions), label)
    return questions


@router.get("/benchmark-questions/{session_id}")
async def get_benchmark_questions(session_id: str):
    """
    Get example questions from the generated benchmark.

    Up to two questions are sampled at random from the ``single_shot_questions``
    config of the session's dataset; if fewer than two are found, the remainder
    is filled from the ``multi_hop_questions`` config.

    Args:
        session_id: Session ID for the benchmark

    Returns:
        Dictionary with ``success``, ``questions`` (list of dicts with ``id``,
        ``question`` and ``type``) and ``dataset_url``. On an unexpected error,
        a dictionary with ``success`` False and an ``error`` message instead.

    Raises:
        HTTPException: 404 when no questions could be loaded and no local
            ``uploaded_files/<session_id>`` directory exists.
    """
    try:
        dataset_repo_id = f"yourbench/yourbench_{session_id}"
        response = {
            "success": False,
            "questions": [],
            "dataset_url": f"https://huggingface.co/datasets/{dataset_repo_id}",
        }

        questions = await _sample_questions_best_effort(
            dataset_repo_id, "single_shot_questions", "single_shot", _MAX_SAMPLE_QUESTIONS
        )
        # Top up from multi-hop questions only when single-shot ones fell short.
        if len(questions) < _MAX_SAMPLE_QUESTIONS:
            questions += await _sample_questions_best_effort(
                dataset_repo_id,
                "multi_hop_questions",
                "multi_hop",
                _MAX_SAMPLE_QUESTIONS - len(questions),
            )

        if not questions:
            # No questions and no local upload either: the session is unknown.
            session_dir = os.path.join("uploaded_files", session_id)
            if not os.path.exists(session_dir):
                raise HTTPException(status_code=404, detail="Dataset not found")

        response["success"] = bool(questions)
        response["questions"] = questions
        return response

    except HTTPException:
        # Let FastAPI turn deliberate HTTP errors into proper responses.
        raise
    except Exception as e:
        _logger.exception("Error retrieving benchmark questions for session %s", session_id)
        return {
            "success": False,
            "error": f"Error retrieving benchmark questions: {str(e)}"
        }