from datetime import datetime, timezone
import json
import os
import uuid

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, validator

from src.envs import API, RESULTS_REPO, EVAL_RESULTS_PATH

router = APIRouter(prefix="/api", tags=["submission"])

ALL_SUBJECTIVE_FIELDS = [
    "readability",
    "relevance",
    "explanation_clarity",
    "problem_identification",
    "actionability",
    "completeness",
    "specificity",
    "contextual_adequacy",
    "consistency",
    "brevity",
]


class ResultPayload(BaseModel):
    model: str = Field(..., description="Model id on the Hub (e.g. org/model)")
    revision: str = Field("main", description="Commit sha or branch (default: main)")
    bleu: float = Field(..., ge=0, le=100, description="BLEU score (0-100)")

    # Ten subjective metrics, each scored on a 0-5 scale
    readability: int = Field(..., ge=0, le=5)
    relevance: int = Field(..., ge=0, le=5)
    explanation_clarity: int = Field(..., ge=0, le=5)
    problem_identification: int = Field(..., ge=0, le=5)
    actionability: int = Field(..., ge=0, le=5)
    completeness: int = Field(..., ge=0, le=5)
    specificity: int = Field(..., ge=0, le=5)
    contextual_adequacy: int = Field(..., ge=0, le=5)
    consistency: int = Field(..., ge=0, le=5)
    brevity: int = Field(..., ge=0, le=5)

    pass_at_1: float = Field(..., ge=0, le=1)
    pass_at_5: float = Field(..., ge=0, le=1)
    pass_at_10: float = Field(..., ge=0, le=1)

    @validator("pass_at_5")
    def _p5_ge_p1(cls, v, values):
        if "pass_at_1" in values and v < values["pass_at_1"]:
            raise ValueError("pass@5 must be >= pass@1")
        return v

    @validator("pass_at_10")
    def _p10_ge_p5(cls, v, values):
        if "pass_at_5" in values and v < values["pass_at_5"]:
            raise ValueError("pass@10 must be >= pass@5")
        return v

    def multimetric(self) -> float:
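        """Mean of the ten subjective metric scores (each on a 0-5 scale)."""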
        total = sum(getattr(self, f) for f in ALL_SUBJECTIVE_FIELDS)
        return float(total) / len(ALL_SUBJECTIVE_FIELDS)
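
# Illustrative request body for ResultPayload (example values only; the ten
# subjective scores are integers on a 0-5 scale, and pass@k values must be
# non-decreasing in k):
#
#   {
#     "model": "org/model", "revision": "main", "bleu": 42.0,
#     "readability": 4, "relevance": 4, "explanation_clarity": 3,
#     "problem_identification": 4, "actionability": 3, "completeness": 4,
#     "specificity": 3, "contextual_adequacy": 4, "consistency": 4,
#     "brevity": 3,
#     "pass_at_1": 0.25, "pass_at_5": 0.40, "pass_at_10": 0.55
#   }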


@router.post("/submit", status_code=200)
async def submit_results(payload: ResultPayload):
    """Accept new evaluation results and push them to the results dataset."""

    # Prepare JSON in expected format (compatible with read_evals.py)
    results_dict = {
        "config": {
            "model_dtype": "unknown",
            "model_name": payload.model,
            "model_sha": payload.revision,
        },
        "results": {},
    }

    # Primary metrics
    results_dict["results"]["bleu"] = {"score": payload.bleu}
    results_dict["results"]["multimetric"] = {"score": payload.multimetric()}

    # Subjective metrics
    for field in ALL_SUBJECTIVE_FIELDS:
        results_dict["results"][field] = {"score": getattr(payload, field)}

    # Pass@k metrics
    results_dict["results"]["pass_at_1"] = {"score": payload.pass_at_1}
    results_dict["results"]["pass_at_5"] = {"score": payload.pass_at_5}
    results_dict["results"]["pass_at_10"] = {"score": payload.pass_at_10}

    # Write results to a uniquely named local file before uploading
    os.makedirs(EVAL_RESULTS_PATH, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    unique_id = uuid.uuid4().hex[:8]
    filename = f"results_{payload.model.replace('/', '_')}_{ts}_{unique_id}.json"
    local_path = os.path.join(EVAL_RESULTS_PATH, filename)

    with open(local_path, "w") as fp:
        json.dump(results_dict, fp)

    try:
        API.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=filename,
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            commit_message=f"Add results for {payload.model}",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to upload results: {e}") from e
    finally:
        if os.path.exists(local_path):
            os.remove(local_path)

    return {"status": "ok", "detail": "Results submitted."}