|
import importlib
import importlib.util
import subprocess
import sys

import gradio as gr
import pkg_resources
|
|
|
def install_package(package):
    """Install *package* with pip into the environment of the running interpreter.

    Runs ``<sys.executable> -m pip install <package>`` so the package is
    installed for the same interpreter that is executing this script.

    Args:
        package: A pip requirement specifier, e.g. ``"torch"``.

    Returns:
        True if the pip command exited with status 0; False if it exited
        non-zero or could not be launched at all.
    """
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    except (subprocess.CalledProcessError, OSError):
        # Only pip failures (non-zero exit) and launch failures (e.g. missing
        # executable) mean "install failed". A bare ``except:`` would also
        # swallow KeyboardInterrupt/SystemExit and make Ctrl-C during a long
        # install look like an ordinary failure.
        return False
    return True
|
|
|
|
|
# Runtime dependencies, listed by the *module* name imported further down.
required_packages = ["torch", "comet"]

# pip distribution that provides each module. The COMET metric library is
# published on PyPI as "unbabel-comet" but imported as ``comet``. Installing
# the distribution literally named "comet" does not provide
# ``download_model``/``load_from_checkpoint``.
PIP_DISTRIBUTIONS = {"torch": "torch", "comet": "unbabel-comet"}

# Check importability rather than distribution metadata. The distribution name
# ("unbabel-comet") differs from the module name ("comet"), so a
# pkg_resources.get_distribution("comet") lookup reports COMET as missing even
# when it is installed.
missing_packages = [
    name for name in required_packages if importlib.util.find_spec(name) is None
]

if missing_packages:
    print(f"Missing packages: {', '.join(missing_packages)}")

    failed_packages = []
    for package in missing_packages:
        pip_name = PIP_DISTRIBUTIONS.get(package, package)
        if install_package(pip_name):
            print(f"Successfully installed {pip_name}")
        else:
            print(f"Failed to install {pip_name}")
            failed_packages.append(pip_name)

    # Abort only when something could not be installed; previously the script
    # exited (and told the user to install manually) even after success.
    if failed_packages:
        print(f"Please install manually: pip install {' '.join(failed_packages)}")
        sys.exit(1)

    # The import system caches directory listings. Refresh them so modules
    # installed after interpreter start-up are found by the imports below.
    importlib.invalidate_caches()
|
|
|
|
|
import torch |
|
from comet import download_model, load_from_checkpoint |
|
|
|
def evaluate_translation(src_text, mt_text):
    """Score a Thai translation of an English sentence with the ComeTH model.

    The "wasanx/ComeTH" checkpoint is downloaded and loaded with COMET on the
    first call only. It is then stored on the function object
    (``evaluate_translation.model``) and reused by later calls.

    Args:
        src_text: English source text.
        mt_text: Thai translation of ``src_text``.

    Returns:
        The first (and only) segment score from the model's prediction,
        converted to ``float``.
    """
    # Lazy one-time load: fetching and initialising the checkpoint is costly.
    try:
        scorer = evaluate_translation.model
    except AttributeError:
        checkpoint = download_model("wasanx/ComeTH")
        scorer = load_from_checkpoint(checkpoint)
        evaluate_translation.model = scorer

    # One segment per call, run on CPU (gpus=0). No "ref" key is supplied.
    # NOTE(review): presumably ComeTH is used reference-free here — confirm.
    segment = {"src": src_text, "mt": mt_text}
    prediction = scorer.predict([segment], batch_size=1, gpus=0)

    first_score = prediction['scores'][0]
    return float(first_score)
|
|
|
# Two text inputs, passed positionally to evaluate_translation as
# (src_text, mt_text); the returned score is displayed as a number.
_source_box = gr.Textbox(label="English Source Text")
_translation_box = gr.Textbox(label="Thai Translation")

# Sample (English source, Thai translation) pairs offered in the UI.
_example_pairs = [
    ["This is a test sentence.", "นี่คือประโยคทดสอบ"],
    ["The weather is nice today.", "อากาศดีมากวันนี้"],
]

demo = gr.Interface(
    fn=evaluate_translation,
    inputs=[_source_box, _translation_box],
    outputs=gr.Number(label="Quality Score"),
    examples=_example_pairs,
    title="ComeTH Translator Evaluator",
)


if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    demo.launch()