|
from fastapi import APIRouter, HTTPException |
|
from fastapi.responses import StreamingResponse |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import os |
|
import tempfile |
|
import shutil |
|
import zipfile |
|
import io |
|
import logging |
|
|
|
# Router collecting the dataset download endpoints; grouped under the "download" tag.
router = APIRouter(tags=["download"])
|
|
|
def _zip_dataset_snapshot(repo_id: str, download_dir: str) -> io.BytesIO:
    """Download a HuggingFace dataset snapshot and pack it into an in-memory ZIP.

    This function is blocking (network and disk I/O), so it is meant to be run
    in a worker thread rather than directly on the event loop.

    Args:
        repo_id: HuggingFace dataset repository identifier.
        download_dir: Existing directory handed to ``snapshot_download`` as
            ``local_dir``.

    Returns:
        A ``BytesIO`` rewound to offset 0 that holds the ZIP archive. Member
        names are paths relative to the downloaded snapshot directory.
    """
    logging.info(f"Téléchargement du dataset {repo_id}")
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=download_dir,
        token=os.environ.get("HF_TOKEN"),
    )
    logging.info(f"Dataset téléchargé dans {snapshot_path}")

    zip_io = io.BytesIO()
    with zipfile.ZipFile(zip_io, "w", zipfile.ZIP_DEFLATED) as zip_file:
        for root, _, files in os.walk(snapshot_path):
            for file in files:
                file_path = os.path.join(root, file)
                # Store paths relative to the snapshot root so the archive
                # does not leak the temporary directory layout.
                arc_name = os.path.relpath(file_path, snapshot_path)
                zip_file.write(file_path, arcname=arc_name)

    # Rewind so the response streams the archive from its first byte.
    zip_io.seek(0)
    return zip_io


@router.get("/download-dataset/{session_id}")
async def download_dataset(session_id: str):
    """Download the HuggingFace dataset tied to a session and return it as a ZIP.

    The dataset repository ``yourbench/yourbench_{session_id}`` is fetched into
    a temporary directory, compressed in memory, and sent back as an
    attachment named ``yourbench_{session_id}_dataset.zip``.

    Args:
        session_id: Session identifier.

    Returns:
        A ``StreamingResponse`` with media type ``application/zip``.

    Raises:
        HTTPException: Status 500 when the download, the archive creation, or
            the temporary-directory handling fails.
    """
    # Imported locally so the handler brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            repo_id = f"yourbench/yourbench_{session_id}"

            try:
                # Download and compression are blocking; run them in a worker
                # thread so the event loop keeps serving other requests.
                zip_io = await run_in_threadpool(
                    _zip_dataset_snapshot, repo_id, temp_dir
                )

                filename = f"yourbench_{session_id}_dataset.zip"
                return StreamingResponse(
                    zip_io,
                    media_type="application/zip",
                    headers={"Content-Disposition": f"attachment; filename={filename}"},
                )
            except Exception as e:
                logging.exception("Erreur lors du téléchargement du dataset: %s", e)
                raise HTTPException(
                    status_code=500,
                    detail=f"Erreur lors du téléchargement du dataset: {str(e)}",
                ) from e
    except HTTPException:
        # Already a well-formed HTTP error from the inner handler. Letting the
        # generic handler below catch it would log it twice and replace its
        # detail with a re-wrapped message.
        raise
    except Exception as e:
        logging.exception("Erreur générale: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement: {str(e)}",
        ) from e