|
import time |
|
from typing import List, Dict, Any, Optional, Union |
|
import numpy as np |
|
from .mini_bench.reward_agent import ProgressJudgeAgent |
|
from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash |
|
import json |
|
import os |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
    """Process one unit and return ``(idx, reward, thought)``.

    Queries a ProgressJudgeAgent for up to ``n_samples`` *valid* reward
    samples, retrying failed generations up to ``max_retries`` rounds,
    then averages the valid rewards and selects the thought tied to the
    most frequent judge hash (majority vote).

    Args:
        idx: Position of ``unit`` in the caller's dataset, returned
            unchanged so threaded callers can restore ordering.
        unit: Item forwarded to ``ProgressJudgeAgent.generate_probs``.
        configs: Agent configuration dict; ``configs["temperature"]``
            seeds the sampling temperature.
        n_samples: Number of valid reward samples to collect.
        reward_processor: Key into ``REWARD_PROCESSORS`` selecting how a
            raw response becomes a scalar reward.
        max_retries: Upper bound on generation rounds.

    Returns:
        Tuple ``(idx, reward, thought)``. ``reward`` falls back to 0.0
        and ``thought`` to ``None`` when no valid sample was obtained.
    """
    agent = ProgressJudgeAgent(configs)
    current_temperature = configs["temperature"]

    rewards = []
    n_err = 0
    retry_count = 0
    thought = None  # fallback when no response is ever parsed
    # judge_hash -> (occurrence count, one representative thought)
    judge_hash_count_thought = {}

    while len(rewards) < n_samples and retry_count < max_retries:

        responses, _ = agent.generate_probs(
            unit, "ours", n=n_samples - len(rewards), temperature=current_temperature
        )

        for response in responses:
            content = response["response"]
            thought = content
            reward = REWARD_PROCESSORS[reward_processor](response)

            # BUG FIX: check None BEFORE np.isnan -- np.isnan(None) raises
            # TypeError, so the original order crashed on None rewards.
            if reward is None or np.isnan(reward):
                # BUG FIX: do NOT append invalid samples. The original
                # appended them, which filled `rewards` to n_samples and
                # silently disabled the retry loop below.
                n_err += 1
            else:
                rewards.append(reward)
                judge_hash = extract_judge_hash(response)
                count, _ = judge_hash_count_thought.get(judge_hash, (0, None))
                judge_hash_count_thought[judge_hash] = (count + 1, thought)

        if n_err > 0:
            # With n_samples == 1 the caller runs at temperature 0.0;
            # raise it so a retry can produce a different response.
            if n_samples == 1:
                current_temperature = 0.5
        retry_count += 1

    # Mean over the valid samples; empty -> NaN (explicit, avoids the
    # RuntimeWarning np.nanmean emits on an empty input).
    reward = np.nanmean(rewards) if rewards else float("nan")
    if np.isnan(reward):
        print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
        reward = 0.0
    print(judge_hash_count_thought)
    # BUG FIX: guard the majority vote -- max() over an empty dict raised
    # ValueError when every sample errored out.
    if judge_hash_count_thought:
        thought = max(judge_hash_count_thought.values(), key=lambda x: x[0])[1]

    return idx, reward, thought
|
|
|
|
|
def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
    """Threaded scorer: reward every unit of ``dataset`` concurrently.

    Fans ``_process_unit`` out over a thread pool and reassembles the
    per-unit results in dataset order.

    Args:
        dataset: Sequence of units to score.
        base_url: Endpoint of the judge-model server.
        model_name: Name of the judge model.
        reward_processor: Key selecting the reward post-processing
            strategy; also determines how many samples are drawn.
        max_workers: Thread-pool size.

    Returns:
        Tuple ``(final_rewards, thoughts)`` -- a list of floats and a
        list of thoughts, both aligned with ``dataset``.
    """
    n_samples = REWARD_PROCESSOR_N_SAMPLES[reward_processor]

    # Single-sample processors decode greedily; multi-sample ones need
    # some temperature for diversity.
    temperature = 0.5 if n_samples > 1 else 0.0

    configs = {
        "model_name": model_name,
        "base_url": base_url,
        "api_key": "empty",
        "temperature": temperature,
        "num_generate": 1,
        "use_checklist": True,
        "input_type": "text_only",
        "text_obs_type": "axtree",
        "image_obs_type": "som",
        "use_in_progress": True,
        "use_multimodal": False,
        "use_log_probs": True,
    }

    started_at = time.time()
    ordered = [None] * len(dataset)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for position, unit in enumerate(dataset):
            pending.append(
                pool.submit(
                    _process_unit, position, unit, configs, n_samples, reward_processor
                )
            )

        # Each worker returns its own index, so completion order is free
        # to differ from submission order.
        for done in as_completed(pending):
            position, unit_reward, unit_thought = done.result()
            ordered[position] = (unit_reward, unit_thought)

    final_rewards = [float(pair[0]) for pair in ordered]
    thoughts = [pair[1] for pair in ordered]

    print(f"Time taken (threaded): {time.time() - started_at:.2f} s")
    return final_rewards, thoughts
|
|
|
|