"""FastAPI router that accepts evaluation-result submissions and uploads them to the results dataset on the Hugging Face Hub."""

from datetime import datetime, timezone
import json
import os
import uuid

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, validator  # pydantic v1 API; on v2 prefer field_validator

from src.envs import API, RESULTS_REPO, EVAL_RESULTS_PATH, TOKEN

router = APIRouter(prefix="/api", tags=["submission"])

# Names of the ten subjective metrics reported with each submission (0-5 each).
ALL_SUBJECTIVE_FIELDS = [
    "readability",
    "relevance",
    "explanation_clarity",
    "problem_identification",
    "actionability",
    "completeness",
    "specificity",
    "contextual_adequacy",
    "consistency",
    "brevity",
]


class ResultPayload(BaseModel):
    model: str = Field(..., description="Model id on the Hub (e.g. org/model)")
    revision: str = Field("main", description="Commit sha or branch (default: main)")
    bleu: float = Field(..., ge=0, description="BLEU score (0-100)")

    # The ten subjective metrics, each scored on a 0-5 scale
    readability: int = Field(..., ge=0, le=5)
    relevance: int = Field(..., ge=0, le=5)
    explanation_clarity: int = Field(..., ge=0, le=5)
    problem_identification: int = Field(..., ge=0, le=5)
    actionability: int = Field(..., ge=0, le=5)
    completeness: int = Field(..., ge=0, le=5)
    specificity: int = Field(..., ge=0, le=5)
    contextual_adequacy: int = Field(..., ge=0, le=5)
    consistency: int = Field(..., ge=0, le=5)
    brevity: int = Field(..., ge=0, le=5)

    # Pass@k metrics, expressed as fractions in [0, 1]
    pass_at_1: float = Field(..., ge=0, le=1)
    pass_at_5: float = Field(..., ge=0, le=1)
    pass_at_10: float = Field(..., ge=0, le=1)

    # pass@k is non-decreasing in k, so enforce the ordering at validation time.
    @validator("pass_at_5")
    def _p5_ge_p1(cls, v, values):
        if "pass_at_1" in values and v < values["pass_at_1"]:
            raise ValueError("pass@5 must be >= pass@1")
        return v

    @validator("pass_at_10")
    def _p10_ge_p5(cls, v, values):
        if "pass_at_5" in values and v < values["pass_at_5"]:
            raise ValueError("pass@10 must be >= pass@5")
        return v
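
    # Illustrative: a payload with pass_at_1=0.5 but pass_at_5=0.3 fails the
    # ordering checks above, so FastAPI rejects it with a 422 response before
    # the /submit handler ever runs.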

    def multimetric(self) -> float:
        """Mean of the ten subjective metrics, on the same 0-5 scale."""
        total = sum(getattr(self, f) for f in ALL_SUBJECTIVE_FIELDS)
        return float(total) / len(ALL_SUBJECTIVE_FIELDS)
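

# Worked example (illustrative): subjective scores of 4, 4, 3, 4, 3, 4, 3, 4,
# 4, 3 sum to 36, so multimetric() reports 36 / 10 == 3.6.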


@router.post("/submit", status_code=200)
async def submit_results(payload: ResultPayload):
    """Accept new evaluation results and push them to the results dataset."""
    # Prepare JSON in the expected format (compatible with read_evals.py).
    results_dict = {
        "config": {
            "model_dtype": "unknown",
            "model_name": payload.model,
            "model_sha": payload.revision,
        },
        "results": {},
    }

    # Primary metrics
    results_dict["results"]["bleu"] = {"score": payload.bleu}
    results_dict["results"]["multimetric"] = {"score": payload.multimetric()}

    # Subjective metrics
    for field in ALL_SUBJECTIVE_FIELDS:
        results_dict["results"][field] = {"score": getattr(payload, field)}

    # Pass@k metrics
    results_dict["results"]["pass_at_1"] = {"score": payload.pass_at_1}
    results_dict["results"]["pass_at_5"] = {"score": payload.pass_at_5}
    results_dict["results"]["pass_at_10"] = {"score": payload.pass_at_10}

    # Write the results to a uniquely named local file, then upload it.
    os.makedirs(EVAL_RESULTS_PATH, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    unique_id = uuid.uuid4().hex[:8]
    filename = f"results_{payload.model.replace('/', '_')}_{ts}_{unique_id}.json"
    local_path = os.path.join(EVAL_RESULTS_PATH, filename)
    with open(local_path, "w") as fp:
        json.dump(results_dict, fp)

    try:
        API.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=filename,
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            commit_message=f"Add results for {payload.model}",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to upload results: {e}") from e
    finally:
        # Remove the local copy whether or not the upload succeeded.
        if os.path.exists(local_path):
            os.remove(local_path)

    return {"status": "ok", "detail": "Results submitted."}