# Pledge_Tracker / system / hero_pipeline.py
# yulongchen's picture
# Add system
# d21fef3
import os
from datetime import datetime
import subprocess
from huggingface_hub import hf_hub_download
import json
def run_hero_reranking(pipeline_base_dir, suggestion_meta) -> dict:
    """Run retrieval and reranking against the augmented data store.

    When ``suggestion_meta`` is truthy, the HyDE entry at
    ``suggestion_meta["index"]`` is read from the
    ``PledgeTracker/demo_feedback`` dataset and written, as a one-element
    list, to ``<pipeline_base_dir>/hero/manifesto_icl_hyde_fc.json``. That
    path is then passed to the retrieval script, and the retrieval output is
    passed to the reranking script. A step is skipped when its output file
    already exists.

    Args:
        pipeline_base_dir: Working directory of the run; it is formatted
            into a string before being joined with sub-paths.
        suggestion_meta: Mapping with an ``"index"`` key selecting one entry
            of the downloaded HyDE data. When falsy, nothing is downloaded
            and the HyDE path is used as-is.

    Returns:
        Dict with the HyDE path (``"hyde"``), the retrieval output path
        (``"retrieved"``) and the reranking output path (``"reranked"``).

    Raises:
        KeyError: If ``suggestion_meta`` is truthy and ``HF_TOKEN`` is not
            set in the environment, or ``"index"`` is missing.
        subprocess.CalledProcessError: If a step exits with a non-zero code.
        subprocess.TimeoutExpired: If a step runs longer than its timeout.
    """
    base_dir = f"{pipeline_base_dir}"
    hero_dir = os.path.join(base_dir, "hero")
    os.makedirs(hero_dir, exist_ok=True)
    # Single source of truth for the HyDE file: written below when a
    # suggestion is given, and always handed to the retrieval step.
    hyde_output = os.path.join(hero_dir, "manifesto_icl_hyde_fc.json")

    if suggestion_meta:
        hyde_path = hf_hub_download(
            repo_id="PledgeTracker/demo_feedback",
            filename="manifesto_icl_hyde_fc.json",
            repo_type="dataset",
            token=os.environ["HF_TOKEN"],
        )
        with open(hyde_path, "r", encoding="utf-8") as f:
            all_hyde_data = json.load(f)

        # Look the entry up before opening the destination, so a bad index
        # does not leave an empty/truncated HyDE file behind.
        single_hyde = [all_hyde_data[suggestion_meta["index"]]]
        with open(hyde_output, "w", encoding="utf-8") as f:
            json.dump(single_hyde, f, indent=2)

    def safe_run(cmd, timeout=600):
        """Run ``cmd``, log failures and timeouts, then re-raise them."""
        # Render once and stringify every part, so the log lines agree and a
        # non-str argument cannot raise inside an exception handler.
        cmd_text = " ".join(str(part) for part in cmd)
        try:
            print(f"👉 Running: {cmd_text}")
            subprocess.run(cmd, check=True, timeout=timeout)
        except subprocess.CalledProcessError as e:
            print(f"[❌ ERROR] Subprocess failed: {e}")
            # stderr is not captured by the call above, so this is normally
            # None; accept bytes or text in case capturing is enabled later.
            if e.stderr:
                stderr_text = e.stderr
                if isinstance(stderr_text, bytes):
                    stderr_text = stderr_text.decode(errors="replace")
                print("[stderr]:", stderr_text)
            raise
        except subprocess.TimeoutExpired:
            print(f"[❌ TIMEOUT] Command timed out: {cmd_text}")
            raise

    # NOTE(review): existing retrieval/rerank outputs are reused even when
    # the HyDE file above was just rewritten -- confirm that callers use a
    # fresh pipeline_base_dir per suggestion.

    # Step 3.2: retrieval
    print("🔍 Step 3.2: Retrieval from knowledge store ...")
    knowledge_store_dir = os.path.join(base_dir, "augmented_data_store")
    retrieval_output = os.path.join(hero_dir, "manifesto_icl_retrieval_top_k_QA.json")
    if not os.path.exists(retrieval_output):
        safe_run([
            "python", "system/baseline/retrieval_optimized.py",
            "--knowledge_store_dir", knowledge_store_dir,
            "--target_data", hyde_output,
            "--json_output", retrieval_output,
        ])

    # Step 3.3: reranking
    print("🏷️ Step 3.3: Reranking retrieved evidence ...")
    rerank_output = os.path.join(hero_dir, "manifesto_icl_reranking_top_k_QA.json")
    if not os.path.exists(rerank_output):
        safe_run([
            "python", "system/baseline/reranking_optimized.py",
            "--target_data", retrieval_output,
            "--json_output", rerank_output,
            "--top_k", str(50),
        ])

    return {
        "hyde": hyde_output,
        "retrieved": retrieval_output,
        "reranked": rerank_output,
    }
def run_hero_pipeline(pipeline_base_dir) -> dict:
    """Run the HyDE, retrieval, reranking and QA-generation steps in order.

    Each step is a script under ``system/baseline/`` launched as a
    subprocess. Outputs are written to ``<pipeline_base_dir>/hero``, and a
    step is skipped when its output file already exists.

    Args:
        pipeline_base_dir: Working directory of the run; it is formatted
            into a string before being joined with sub-paths. The HyDE step
            is pointed at ``claim.json`` and the retrieval step at
            ``initial_data_store`` inside this directory.

    Returns:
        Dict with the HyDE output path (``"hyde"``), the retrieval output
        path (``"retrieved"``), the reranking output path (``"reranked"``)
        and the generated QA output path (``"qa_pairs"``).

    Raises:
        subprocess.CalledProcessError: If a step exits with a non-zero code.
        subprocess.TimeoutExpired: If a step runs longer than its timeout.
    """
    base_dir = f"{pipeline_base_dir}"
    hero_dir = os.path.join(base_dir, "hero")
    os.makedirs(hero_dir, exist_ok=True)

    target_data = os.path.join(base_dir, "claim.json")
    hyde_output = os.path.join(hero_dir, "manifesto_icl_hyde_fc.json")

    def safe_run(cmd, timeout=600):
        """Run ``cmd``, log failures and timeouts, then re-raise them."""
        # Render once and stringify every part, so the log lines agree and a
        # non-str argument cannot raise inside an exception handler.
        cmd_text = " ".join(str(part) for part in cmd)
        try:
            print(f"👉 Running: {cmd_text}")
            subprocess.run(cmd, check=True, timeout=timeout)
        except subprocess.CalledProcessError as e:
            print(f"[❌ ERROR] Subprocess failed: {e}")
            # stderr is not captured by the call above, so this is normally
            # None; accept bytes or text in case capturing is enabled later.
            if e.stderr:
                stderr_text = e.stderr
                if isinstance(stderr_text, bytes):
                    stderr_text = stderr_text.decode(errors="replace")
                print("[stderr]:", stderr_text)
            raise
        except subprocess.TimeoutExpired:
            print(f"[❌ TIMEOUT] Command timed out: {cmd_text}")
            raise

    # Step 3.1: hyde_fc_generation
    if not os.path.exists(hyde_output):
        print("🧠 Step 3.1: HyDE ICL generation ...")
        safe_run([
            "python", "system/baseline/hyde_fc_generation_optimized.py",
            "--target_data", target_data,
            "--json_output", hyde_output,
        ])

    # Step 3.2: retrieval
    print("🔍 Step 3.2: Retrieval from knowledge store ...")
    knowledge_store_dir = os.path.join(base_dir, "initial_data_store")
    retrieval_output = os.path.join(hero_dir, "manifesto_icl_retrieval_top_k.json")
    if not os.path.exists(retrieval_output):
        safe_run([
            "python", "system/baseline/retrieval_optimized.py",
            "--knowledge_store_dir", knowledge_store_dir,
            "--target_data", hyde_output,
            "--json_output", retrieval_output,
        ])

    # Step 3.3: reranking
    print("🏷️ Step 3.3: Reranking retrieved evidence ...")
    rerank_output = os.path.join(hero_dir, "manifesto_icl_reranking_top_k.json")
    if not os.path.exists(rerank_output):
        safe_run([
            "python", "system/baseline/reranking_optimized.py",
            "--target_data", retrieval_output,
            "--json_output", rerank_output,
        ])

    # Step 3.4: question generation
    print("❓ Step 3.4: Generating QA pairs ...")
    reference_corpus = "system/baseline/train.json"
    qa_output = os.path.join(hero_dir, "manifesto_icl_top_k_qa.json")
    if not os.path.exists(qa_output):
        safe_run([
            "python", "system/baseline/question_generation_optimized.py",
            "--reference_corpus", reference_corpus,
            "--top_k_target_knowledge", rerank_output,
            "--output_questions", qa_output,
            "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct",
        ])

    return {
        "hyde": hyde_output,
        "retrieved": retrieval_output,
        "reranked": rerank_output,
        "qa_pairs": qa_output,
    }