Tsunnami committed on
Commit 358f50f · verified · 1 Parent(s): 0aa33ec

Update app.py

Files changed (1)
  1. app.py +35 -13
app.py CHANGED
@@ -1,27 +1,49 @@
  import gradio as gr
  import sys
- import torch

- try:
-     from unbabel.comet import download_model, load_from_checkpoint
- except ImportError:
-     print("Error: unbabel-comet package not installed")
-     print("Install with: pip install unbabel-comet torch gradio")
-     sys.exit(1)

  def evaluate_translation(src_text, mt_text):
      if not hasattr(evaluate_translation, "model"):
-         try:
-             model_path = download_model("wasanx/ComeTH")
-             evaluate_translation.model = load_from_checkpoint(model_path)
-         except Exception as e:
-             return f"Error loading model: {str(e)}"

      translations = [{"src": src_text, "mt": mt_text}]
      results = evaluate_translation.model.predict(
          translations,
          batch_size=1,
-         gpus=0 if not torch.cuda.is_available() else 1
      )
      return float(results['scores'][0])
 
  import gradio as gr
  import sys
+ import subprocess
+ import pkg_resources
+
+ def install_package(package):
+     try:
+         subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+         return True
+     except:
+         return False
+
+ # Check and install missing dependencies
+ required_packages = ["torch", "comet"]
+ missing_packages = []

+ for package in required_packages:
+     try:
+         pkg_resources.get_distribution(package)
+     except pkg_resources.DistributionNotFound:
+         missing_packages.append(package)
+
+ if missing_packages:
+     print(f"Missing packages: {', '.join(missing_packages)}")
+     for package in missing_packages:
+         if install_package(package):
+             print(f"Successfully installed {package}")
+         else:
+             print(f"Failed to install {package}")
+             print(f"Please install manually: pip install {' '.join(required_packages)}")
+             sys.exit(1)
+
+ # Now import torch and comet after ensuring they're installed
+ import torch
+ from comet import download_model, load_from_checkpoint

  def evaluate_translation(src_text, mt_text):
      if not hasattr(evaluate_translation, "model"):
+         model_path = download_model("wasanx/ComeTH")
+         evaluate_translation.model = load_from_checkpoint(model_path)

      translations = [{"src": src_text, "mt": mt_text}]
      results = evaluate_translation.model.predict(
          translations,
          batch_size=1,
+         gpus=0
      )
      return float(results['scores'][0])
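The new startup check looks up a distribution named "comet", while the removed error message installs via pip install unbabel-comet; pip distribution names and Python import names do not always match, so probing importability is one alternative. A minimal sketch of that idea (not part of this commit; the helper name is made up for illustration):

import importlib.util

def is_importable(module_name):
    # find_spec returns None when the module cannot be located on sys.path
    return importlib.util.find_spec(module_name) is not None

# "torch" and "comet" are import names here, regardless of which
# distribution (e.g. unbabel-comet) actually provides them.
missing = [name for name in ("torch", "comet") if not is_importable(name)]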
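The hunk ends at the scoring function, so the UI wiring that import gradio as gr hints at is not visible here. A minimal sketch of how evaluate_translation could be exposed as a Gradio interface; the component labels and launch call below are assumptions, not taken from app.py:

# Hypothetical UI wiring (assumed, not shown in this diff): two text inputs
# feed evaluate_translation and the returned score is displayed as a number.
demo = gr.Interface(
    fn=evaluate_translation,
    inputs=[gr.Textbox(label="Source text"), gr.Textbox(label="Machine translation")],
    outputs=gr.Number(label="ComeTH score"),
)

if __name__ == "__main__":
    demo.launch()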