# COMETH / app.py
# Last commit: "Update app.py" by Tsunnami (baf753b, verified)
import gradio as gr
from comet import download_model, load_from_checkpoint
import os
# Resolve the COMET checkpoint to load. HF_MODEL_PATH, when set to a
# non-empty value, overrides the hub model. The download is done only in the
# fallback branch: passing download_model(...) as the default of
# os.environ.get() would evaluate it eagerly and always hit the hub.
model_path = os.environ.get("HF_MODEL_PATH")
if not model_path:
    model_path = download_model("wasanx/ComeTH")
model = load_from_checkpoint(model_path)
def score_translation(src_text, mt_text, gpus=None):
    """Score one source/candidate translation pair with the loaded COMET model.

    Args:
        src_text: Source sentence (English, per the UI labels).
        mt_text: Candidate translation to evaluate (Thai, per the UI labels).
        gpus: GPU count forwarded to ``model.predict``. ``None`` (default)
            uses one GPU when CUDA is available and the CPU (0) otherwise.

    Returns:
        The score for the pair, i.e. the first entry of ``results["scores"]``.
    """
    if gpus is None:
        # Previously hard-coded to 1, which is expected to fail on CPU-only
        # hosts; pick the device count from what is actually available.
        import torch

        gpus = 1 if torch.cuda.is_available() else 0
    translations = [{"src": src_text, "mt": mt_text}]
    results = model.predict(translations, batch_size=1, gpus=gpus)
    return results["scores"][0]
# Example rows for gr.Examples; each row is [English source, Thai candidate].
# The "good" and "bad" sets pair the same four source sentences with
# different Thai candidates.
_EXAMPLE_SOURCES = (
    "The weather is beautiful today.",
    "I need to go to the hospital.",
    "This restaurant serves delicious food.",
    "Can you help me find the nearest train station?",
)
# Candidates that render the meaning of the matching source.
_GOOD_TRANSLATIONS = (
    "วันนี้อากาศดีมาก",
    "ฉันต้องไปโรงพยาบาล",
    "ร้านอาหารนี้เสิร์ฟอาหารอร่อย",
    "คุณช่วยฉันหาสถานีรถไฟที่ใกล้ที่สุดได้ไหม",
)
# Candidates that do not convey the matching source's meaning.
_BAD_TRANSLATIONS = (
    "วันนี้อากาศแย่สุดๆ",
    "ฉันอยากกินข้าว",
    "ร้านนี้ไม่อร่อยเลย",
    "คุณพูดภาษาอังกฤษได้ไหม?",
)
good_examples = [[src, mt] for src, mt in zip(_EXAMPLE_SOURCES, _GOOD_TRANSLATIONS)]
bad_examples = [[src, mt] for src, mt in zip(_EXAMPLE_SOURCES, _BAD_TRANSLATIONS)]
# Custom stylesheet passed to gr.Blocks(css=...): imports JetBrains Mono
# (weights 400/700) from Google Fonts and applies it to every element via the
# universal selector, with !important so it overrides the theme's fonts.
# NOTE(review): JetBrains Mono presumably has no Thai glyphs, so Thai text
# likely renders in the browser's generic monospace fallback — confirm visually.
font_css = """
@import url("https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&display=swap");
* {
font-family: 'JetBrains Mono', monospace !important;
}
"""
# Build the Gradio UI: inputs and an evaluate button in the left column, the
# score in the right column, then the good/bad example tables underneath.
with gr.Blocks(theme=gr.themes.Soft(), css=font_css) as demo:
    gr.Markdown("# ComeTH Translation Quality Estimator")
    with gr.Row():
        with gr.Column(scale=1):
            src_input = gr.Textbox(label="Source Text (English)", placeholder="Enter English text here...")
            mt_input = gr.Textbox(label="Candidate Translation (Thai)", placeholder="Enter Thai translation here...")
            score_button = gr.Button("Evaluate Translation", variant="primary")
        with gr.Column(scale=1):
            # score_translation returns a single scalar per pair, so the label
            # and caption describe one score (they previously implied several
            # scores "across multiple dimensions").
            score_output = gr.Label(label="Quality Score")
            gr.Markdown("### Higher scores indicate better translation quality")
    # Identical wiring for both example sets. Clicking a row fills the two
    # inputs; supplying fn/outputs lets Gradio pre-compute the example scores
    # when example caching is enabled.
    for heading, examples in (
        ("## Good Translation Examples", good_examples),
        ("## Bad Translation Examples", bad_examples),
    ):
        gr.Markdown(heading)
        gr.Examples(
            examples=examples,
            inputs=[src_input, mt_input],
            outputs=score_output,
            fn=score_translation,
        )
    score_button.click(fn=score_translation, inputs=[src_input, mt_input], outputs=score_output)

if __name__ == "__main__":
    demo.launch()