Tsunnami commited on
Commit
292e32a
·
verified ·
1 Parent(s): 2044667

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -61
app.py CHANGED
@@ -1,65 +1,24 @@
1
  import gradio as gr
2
- import sys
3
- import subprocess
4
- import pkg_resources
5
-
6
- def install_package(package):
7
- try:
8
- subprocess.check_call([sys.executable, "-m", "pip", "install", package])
9
- return True
10
- except:
11
- return False
12
-
13
- # Check and install missing dependencies
14
- required_packages = ["torch", "unbabel-comet"]
15
- missing_packages = []
16
-
17
- for package in required_packages:
18
- try:
19
- pkg_resources.get_distribution(package)
20
- except pkg_resources.DistributionNotFound:
21
- missing_packages.append(package)
22
-
23
- if missing_packages:
24
- print(f"Missing packages: {', '.join(missing_packages)}")
25
- for package in missing_packages:
26
- if install_package(package):
27
- print(f"Successfully installed {package}")
28
- else:
29
- print(f"Failed to install {package}")
30
- print(f"Please install manually: pip install {' '.join(required_packages)}")
31
- sys.exit(1)
32
-
33
- # Now import torch and comet after ensuring they're installed
34
- import torch
35
  from comet import download_model, load_from_checkpoint
36
 
37
- def evaluate_translation(src_text, mt_text):
38
- if not hasattr(evaluate_translation, "model"):
39
- model_path = download_model("wasanx/ComeTH")
40
- evaluate_translation.model = load_from_checkpoint(model_path)
41
-
42
- translations = [{"src": src_text, "mt": mt_text}]
43
- results = evaluate_translation.model.predict(
44
- translations,
45
- batch_size=1,
46
- gpus=0
47
- )
48
- return float(results['scores'][0])
49
-
50
- demo = gr.Interface(
51
- fn=evaluate_translation,
52
- inputs=[
53
- gr.Textbox(label="English Source Text"),
54
- gr.Textbox(label="Thai Translation")
55
- ],
56
- outputs=gr.Number(label="Quality Score"),
57
- examples=[
58
- ["This is a test sentence.", "นี่คือประโยคทดสอบ"],
59
- ["The weather is nice today.", "อากาศดีมากวันนี้"]
60
- ],
61
- title="ComeTH Translator Evaluator"
62
- )
63
 
64
- if __name__ == "__main__":
65
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from comet import download_model, load_from_checkpoint
3
 
4
+ model_path = download_model("wasanx/ComeTH")
5
+ model = load_from_checkpoint(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ def score_translation(src_text, mt_text):
8
+ translations = [{"src": src_text, "mt": mt_text}]
9
+ results = model.predict(translations, batch_size=1, gpus=1)
10
+ scores = results.get('scores', [[]])[0]
11
+ return {f"Score {i+1}": score for i, score in enumerate(scores)}
12
+
13
+ with gr.Blocks() as demo:
14
+ gr.Markdown("# Translation Quality Scoring with ComeTH Model")
15
+ with gr.Row():
16
+ with gr.Column():
17
+ src_input = gr.Textbox(label="Source Text (English)")
18
+ mt_input = gr.Textbox(label="Machine Translation Text (Thai)")
19
+ with gr.Column():
20
+ score_output = gr.Label(num_top_classes=5, label="Quality Scores")
21
+ score_button = gr.Button("Score Translation")
22
+ score_button.click(fn=score_translation, inputs=[src_input, mt_input], outputs=[score_output])
23
+
24
+ demo.launch(server_name="0.0.0.0", server_port=7860)