from datetime import datetime, timezone
import json
import os
import re
import uuid

from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field, validator

from src.envs import API, RESULTS_REPO, EVAL_RESULTS_PATH, TOKEN

router = APIRouter(prefix="/api", tags=["submission"])

# Names of the ten subjective metrics. This list drives both the
# ``multimetric`` average and the per-metric entries of the exported JSON.
ALL_SUBJECTIVE_FIELDS = [
    "readability",
    "relevance",
    "explanation_clarity",
    "problem_identification",
    "actionability",
    "completeness",
    "specificity",
    "contextual_adequacy",
    "consistency",
    "brevity",
]

# Anything outside this character set is replaced by "_" when the model id is
# embedded in a filename. For ordinary "org/model" ids this gives the same
# result as replacing "/" alone.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]")


class ResultPayload(BaseModel):
    """Request body of ``POST /api/submit``: one model's evaluation scores."""

    model: str = Field(..., description="Model id on the Hub (e.g. org/model)")
    revision: str = Field("main", description="Commit sha or branch (default: main)")
    # Upper bound added so validation matches the documented 0-100 range.
    bleu: float = Field(..., ge=0, le=100, description="BLEU score (0-100)")

    # 10 subjective metrics 0-5
    readability: int = Field(..., ge=0, le=5)
    relevance: int = Field(..., ge=0, le=5)
    explanation_clarity: int = Field(..., ge=0, le=5)
    problem_identification: int = Field(..., ge=0, le=5)
    actionability: int = Field(..., ge=0, le=5)
    completeness: int = Field(..., ge=0, le=5)
    specificity: int = Field(..., ge=0, le=5)
    contextual_adequacy: int = Field(..., ge=0, le=5)
    consistency: int = Field(..., ge=0, le=5)
    brevity: int = Field(..., ge=0, le=5)

    # Pass@k rates, each in [0, 1]. The validators below also require them to
    # be non-decreasing in k.
    pass_at_1: float = Field(..., ge=0, le=1)
    pass_at_5: float = Field(..., ge=0, le=1)
    pass_at_10: float = Field(..., ge=0, le=1)

    @validator("pass_at_5")
    def _p5_ge_p1(cls, v, values):
        """Reject a pass@5 that is lower than pass@1."""
        # ``pass_at_1`` is missing from ``values`` when it failed its own
        # validation; the comparison is skipped in that case.
        if "pass_at_1" in values and v < values["pass_at_1"]:
            raise ValueError("pass@5 must be >= pass@1")
        return v

    @validator("pass_at_10")
    def _p10_ge_p5(cls, v, values):
        """Reject a pass@10 that is lower than pass@5."""
        if "pass_at_5" in values and v < values["pass_at_5"]:
            raise ValueError("pass@10 must be >= pass@5")
        return v

    def multimetric(self) -> float:
        """Return the arithmetic mean of the ten subjective metrics."""
        total = sum(getattr(self, f) for f in ALL_SUBJECTIVE_FIELDS)
        return float(total) / len(ALL_SUBJECTIVE_FIELDS)


def _build_results_dict(payload: ResultPayload) -> dict:
    """Assemble the JSON document for *payload* (compatible with read_evals.py)."""
    results = {
        # Primary metrics
        "bleu": {"score": payload.bleu},
        "multimetric": {"score": payload.multimetric()},
    }
    # Subjective metrics
    for field in ALL_SUBJECTIVE_FIELDS:
        results[field] = {"score": getattr(payload, field)}
    # Pass@k metrics
    results["pass_at_1"] = {"score": payload.pass_at_1}
    results["pass_at_5"] = {"score": payload.pass_at_5}
    results["pass_at_10"] = {"score": payload.pass_at_10}

    return {
        "config": {
            "model_dtype": "unknown",
            "model_name": payload.model,
            "model_sha": payload.revision,
        },
        "results": results,
    }


def _stage_and_upload(results_dict: dict, filename: str, model: str) -> None:
    """Write *results_dict* to a staging file, upload it, then remove the file.

    Raises:
        HTTPException: status 500 when the upload call fails.
    """
    os.makedirs(EVAL_RESULTS_PATH, exist_ok=True)
    local_path = os.path.join(EVAL_RESULTS_PATH, filename)
    try:
        # The write is inside the cleanup scope so that a failed or partial
        # write does not leave a stray file behind.
        with open(local_path, "w", encoding="utf-8") as fp:
            json.dump(results_dict, fp)
        try:
            API.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=filename,
                repo_id=RESULTS_REPO,
                repo_type="dataset",
                commit_message=f"Add results for {model}",
            )
        except Exception as exc:
            raise HTTPException(
                status_code=500, detail=f"Failed to upload results: {exc}"
            ) from exc
    finally:
        if os.path.exists(local_path):
            os.remove(local_path)


@router.post("/submit", status_code=200)
async def submit_results(payload: ResultPayload):
    """Accept new evaluation results and push them to the results dataset.

    Args:
        payload: Validated scores for a single model/revision.

    Returns:
        ``{"status": "ok", "detail": "Results submitted."}`` on success.

    Raises:
        HTTPException: status 500 when uploading the results file fails.
    """
    results_dict = _build_results_dict(payload)

    # The timestamp plus a short random suffix keeps filenames unique across
    # repeated submissions for the same model.
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    unique_id = uuid.uuid4().hex[:8]
    # ``payload.model`` comes straight from the request, so neutralise every
    # character that is not safe inside a single path component.
    safe_model = _UNSAFE_FILENAME_CHARS.sub("_", payload.model)
    filename = f"results_{safe_model}_{ts}_{unique_id}.json"

    # Directory creation, the file write and the upload are synchronous calls;
    # run them in a worker thread so this coroutine does not block the event
    # loop while they complete.
    await run_in_threadpool(_stage_and_upload, results_dict, filename, payload.model)

    return {"status": "ok", "detail": "Results submitted."}