Tsunnami committed
Commit 0aa33ec · verified · 1 Parent(s): 6721032

Update app.py

Files changed (1)
  1. app.py +20 -7
app.py CHANGED
@@ -1,17 +1,30 @@
 import gradio as gr
-from comet import download_model, load_from_checkpoint
+import sys
+import torch
 
-def load_cometh_model():
-    model_path = download_model("wasanx/ComeTH")
-    return load_from_checkpoint(model_path)
+try:
+    from unbabel.comet import download_model, load_from_checkpoint
+except ImportError:
+    print("Error: unbabel-comet package not installed")
+    print("Install with: pip install unbabel-comet torch gradio")
+    sys.exit(1)
 
 def evaluate_translation(src_text, mt_text):
+    if not hasattr(evaluate_translation, "model"):
+        try:
+            model_path = download_model("wasanx/ComeTH")
+            evaluate_translation.model = load_from_checkpoint(model_path)
+        except Exception as e:
+            return f"Error loading model: {str(e)}"
+
     translations = [{"src": src_text, "mt": mt_text}]
-    results = model.predict(translations, batch_size=1, gpus=1)
+    results = evaluate_translation.model.predict(
+        translations,
+        batch_size=1,
+        gpus=0 if not torch.cuda.is_available() else 1
+    )
     return float(results['scores'][0])
 
-model = load_cometh_model()
-
 demo = gr.Interface(
     fn=evaluate_translation,
     inputs=[
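
The updated evaluate_translation caches the loaded model as an attribute on the function object, so the ComeTH checkpoint is downloaded only on the first request rather than at import time. A minimal, self-contained sketch of that caching pattern, with expensive_setup and score as hypothetical stand-ins (not part of the commit) for download_model/load_from_checkpoint and the real predict call:

    # The result of the expensive setup is stashed on the function itself,
    # so it is computed once and reused on every later call.
    def expensive_setup():
        print("loading model...")   # runs only on the first call
        return {"loaded": True}

    def score(text):
        if not hasattr(score, "model"):
            score.model = expensive_setup()
        return f"scored {text!r} with {score.model}"

    print(score("hello"))  # triggers the one-time load
    print(score("world"))  # reuses the cached object

Since Gradio invokes the handler per request, this keeps app startup fast while avoiding a reload of the checkpoint on every evaluation.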