Tsunnami commited on
Commit
f7f6e8d
·
verified ·
1 Parent(s): 0436b7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -11
app.py CHANGED
@@ -1,24 +1,38 @@
1
  import gradio as gr
2
  from comet import download_model, load_from_checkpoint
 
3
 
4
- model_path = download_model("wasanx/ComeTH")
5
  model = load_from_checkpoint(model_path)
6
 
7
  def score_translation(src_text, mt_text):
8
  translations = [{"src": src_text, "mt": mt_text}]
9
  results = model.predict(translations, batch_size=1, gpus=1)
10
- scores = results['scores'][0]
11
- return scores
12
 
13
- with gr.Blocks() as demo:
14
- gr.Markdown("# Translation Quality Scoring with ComeTH Model")
 
 
 
 
 
 
 
 
15
  with gr.Row():
16
  with gr.Column():
17
- src_input = gr.Textbox(label="Source Text (English)")
18
- mt_input = gr.Textbox(label="Machine Translation Text (Thai)")
 
 
19
  with gr.Column():
20
- score_output = gr.Label(num_top_classes=5, label="Quality Scores")
21
- score_button = gr.Button("Score Translation")
22
- score_button.click(fn=score_translation, inputs=[src_input, mt_input], outputs=[score_output])
 
 
 
23
 
24
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
  from comet import download_model, load_from_checkpoint
3
+ import os
4
 
5
+ model_path = os.environ.get("HF_MODEL_PATH", download_model("wasanx/ComeTH"))
6
  model = load_from_checkpoint(model_path)
7
 
8
  def score_translation(src_text, mt_text):
9
  translations = [{"src": src_text, "mt": mt_text}]
10
  results = model.predict(translations, batch_size=1, gpus=1)
11
+ return results['scores'][0]
 
12
 
13
+ examples = [
14
+ ["The weather is beautiful today.", "วันนี้อากาศดีมาก"],
15
+ ["I need to go to the hospital.", "ฉันต้องไปโรงพยาบาล"],
16
+ ["This restaurant serves delicious food.", "ร้านอาหารนี้เสิร์ฟอาหารอร่อย"],
17
+ ["Can you help me find the nearest train station?", "คุณช่วยฉันหาสถานีรถไฟที่ใกล้ที่สุดได้ไหม"]
18
+ ]
19
+
20
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
21
+ gr.Markdown("# ComeTH Translation Quality Evaluator")
22
+
23
  with gr.Row():
24
  with gr.Column():
25
+ src_input = gr.Textbox(label="Source Text (English)", placeholder="Enter English text here...")
26
+ mt_input = gr.Textbox(label="Machine Translation (Thai)", placeholder="Enter Thai translation here...")
27
+ score_button = gr.Button("Evaluate Translation", variant="primary")
28
+
29
  with gr.Column():
30
+ score_output = gr.Label(label="Quality Scores")
31
+ gr.Markdown("### Higher scores indicate better translation quality across multiple dimensions")
32
+
33
+ gr.Examples(examples=examples, inputs=[src_input, mt_input], outputs=score_output, fn=score_translation)
34
+
35
+ score_button.click(fn=score_translation, inputs=[src_input, mt_input], outputs=score_output)
36
 
37
+ if __name__ == "__main__":
38
+ demo.launch()