|
from fastapi import APIRouter, HTTPException |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import os |
|
import tempfile |
|
import shutil |
|
import zipfile |
|
import io |
|
import logging |
|
import json |
|
from datasets import load_dataset |
|
|
|
# Router collecting the download endpoints defined in this module; its routes are tagged "download".
router = APIRouter(tags=["download"]) |
|
|
|
def _zip_directory(source_dir: str) -> io.BytesIO:
    """Pack every file under ``source_dir`` into an in-memory ZIP archive.

    Args:
        source_dir: Directory whose files are archived recursively.

    Returns:
        A ``BytesIO`` holding the deflate-compressed archive, rewound to
        offset 0. Entry names are paths relative to ``source_dir``.
    """
    zip_io = io.BytesIO()
    with zipfile.ZipFile(zip_io, "w", zipfile.ZIP_DEFLATED) as zip_file:
        for root, _, files in os.walk(source_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arc_name = os.path.relpath(file_path, source_dir)
                zip_file.write(file_path, arcname=arc_name)
    zip_io.seek(0)
    return zip_io


def _fetch_dataset_archive(repo_id: str) -> io.BytesIO:
    """Download a dataset snapshot and return it as an in-memory ZIP archive.

    Blocking helper (network + disk + compression); call it from a worker
    thread rather than directly from the event loop.

    Args:
        repo_id: Dataset repository identifier passed to ``snapshot_download``.

    Returns:
        A ``BytesIO`` containing the zipped snapshot, positioned at offset 0.
    """
    # The archive is built fully in memory, so the temporary directory can be
    # removed as soon as this function returns.
    with tempfile.TemporaryDirectory() as temp_dir:
        logging.info(f"Téléchargement du dataset {repo_id}")
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=temp_dir,
            token=os.environ.get("HF_TOKEN"),
        )
        logging.info(f"Dataset téléchargé dans {snapshot_path}")
        return _zip_directory(snapshot_path)


@router.get("/download-dataset/{session_id}")
async def download_dataset(session_id: str):
    """Download the HuggingFace dataset tied to a session and send it as a ZIP.

    Args:
        session_id: Session identifier; the dataset repository is
            ``yourbench/yourbench_{session_id}``.

    Returns:
        A ``StreamingResponse`` (``application/zip``) served as an attachment
        named ``yourbench_{session_id}_dataset.zip``.

    Raises:
        HTTPException: 500 when the download or the archive creation fails.
    """
    # Local import keeps this block self-contained; the helper ships with FastAPI.
    from fastapi.concurrency import run_in_threadpool

    repo_id = f"yourbench/yourbench_{session_id}"

    try:
        # snapshot_download and zip compression are blocking; running them in
        # the thread pool keeps the event loop free for other requests.
        zip_io = await run_in_threadpool(_fetch_dataset_archive, repo_id)
    except Exception as e:
        # Single error boundary: the HTTPException raised here is not caught
        # and re-wrapped by a second generic handler.
        logging.exception(f"Erreur lors du téléchargement du dataset: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement du dataset: {str(e)}",
        ) from e

    filename = f"yourbench_{session_id}_dataset.zip"
    return StreamingResponse(
        zip_io,
        media_type="application/zip",
        headers={"Content-Disposition": f"attachment; filename={filename}"},
    )
|
|
|
def _load_questions(repo_id: str, question_type: str, start_idx: int) -> list:
    """Load one question configuration of a dataset as plain dictionaries.

    Best-effort: any failure is logged and yields an empty list, so one
    missing configuration does not prevent the other from being served.

    Args:
        repo_id: Dataset repository identifier passed to ``load_dataset``.
        question_type: ``"single_shot"`` or ``"multi_hop"``; the configuration
            loaded is ``"{question_type}_questions"``.
        start_idx: Value used as ``id`` for the first question; following
            questions are numbered consecutively.

    Returns:
        A list of dicts with the keys ``id``, ``question``, ``answer`` and
        ``type``. Empty when the configuration is missing, empty or unreadable.
    """
    label = question_type.replace("_", "-")  # e.g. "single_shot" -> "single-shot"
    try:
        dataset = load_dataset(repo_id, f"{question_type}_questions")
        rows = dataset["train"] if dataset else []
        # Iterate rows once instead of indexing each row twice by position.
        questions = [
            {
                "id": str(start_idx + idx),
                "question": row.get("question", ""),
                "answer": row.get("self_answer", "No answer available"),
                "type": question_type,
            }
            for idx, row in enumerate(rows)
        ]
    except Exception as e:
        logging.error(f"Error loading {label} questions: {str(e)}")
        return []

    if questions:
        logging.info(f"Loaded {len(questions)} {label} questions")
    return questions


def _collect_questions(repo_id: str) -> list:
    """Gather single-shot then multi-hop questions with consecutive ids.

    Blocking helper (dataset loading); call it from a worker thread.
    """
    questions = _load_questions(repo_id, "single_shot", 0)
    questions.extend(_load_questions(repo_id, "multi_hop", len(questions)))
    return questions


@router.get("/download-questions/{session_id}")
async def download_questions(session_id: str):
    """Download the questions generated for a session as a JSON file.

    Args:
        session_id: Session identifier; the dataset repository is
            ``yourbench/yourbench_{session_id}``.

    Returns:
        A ``StreamingResponse`` (``application/json``) served as an attachment
        named ``yourbench_{session_id}_questions.json``. The payload has the
        keys ``session_id`` and ``questions``.

    Raises:
        HTTPException: 404 when no question could be loaded, 500 on any other
            failure.
    """
    # Local import keeps this block self-contained; the helper ships with FastAPI.
    from fastapi.concurrency import run_in_threadpool

    try:
        dataset_repo_id = f"yourbench/yourbench_{session_id}"

        # load_dataset is blocking; run it off the event loop.
        all_questions = await run_in_threadpool(_collect_questions, dataset_repo_id)

        if not all_questions:
            raise HTTPException(status_code=404, detail="Aucune question trouvée pour cette session")

        questions_json = json.dumps(
            {
                "session_id": session_id,
                "questions": all_questions,
            },
            ensure_ascii=False,
            indent=2,
        )

        # A BytesIO created from bytes already starts at offset 0.
        json_bytes = io.BytesIO(questions_json.encode("utf-8"))

        filename = f"yourbench_{session_id}_questions.json"
        return StreamingResponse(
            json_bytes,
            media_type="application/json",
            headers={"Content-Disposition": f"attachment; filename={filename}"},
        )

    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged.
        raise
    except Exception as e:
        logging.error(f"Erreur lors de la récupération des questions: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement des questions: {str(e)}",
        ) from e