AshwinSankar committed on
Commit 24bcef0 · verified · 1 Parent(s): dd75568

Update app.py

Files changed (1)
  1. app.py +26 -66
app.py CHANGED
@@ -28,72 +28,29 @@ SARVAM_LANGUAGES = INDIC_LANGUAGES
 TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 DEVICE_MAP = "auto" if torch.cuda.is_available() else None
 
-class ModelManager:
-    def __init__(self):
-        self.indictrans_model = None
-        self.indictrans_tokenizer = None
-        self.sarvam_model = None
-        self.sarvam_tokenizer = None
-        self.current_model = None
-
-    def load_indictrans_model(self):
-        if self.indictrans_model is None:
-            try:
-                self.indictrans_model = AutoModelForCausalLM.from_pretrained(
-                    "ai4bharat/IndicTrans3-beta",
-                    torch_dtype=TORCH_DTYPE,
-                    device_map=DEVICE_MAP,
-                    token=HF_TOKEN,
-                    low_cpu_mem_usage=True,
-                    trust_remote_code=True
-                )
-                self.indictrans_tokenizer = AutoTokenizer.from_pretrained(
-                    "ai4bharat/IndicTrans3-beta",
-                    trust_remote_code=True
-                )
-                # Enable optimizations
-                if hasattr(self.indictrans_model, 'eval'):
-                    self.indictrans_model.eval()
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-            except Exception as e:
-                print(f"Error loading IndicTrans model: {e}")
-
-    def load_sarvam_model(self):
-        if self.sarvam_model is None:
-            try:
-                self.sarvam_model = AutoModelForCausalLM.from_pretrained(
-                    "sarvamai/sarvam-translate",
-                    torch_dtype=TORCH_DTYPE,
-                    device_map=DEVICE_MAP,
-                    token=HF_TOKEN,
-                    low_cpu_mem_usage=True,
-                    trust_remote_code=True
-                )
-                self.sarvam_tokenizer = AutoTokenizer.from_pretrained(
-                    "sarvamai/sarvam-translate",
-                    trust_remote_code=True
-                )
-                # Enable optimizations
-                if hasattr(self.sarvam_model, 'eval'):
-                    self.sarvam_model.eval()
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-            except Exception as e:
-                print(f"Error loading Sarvam model: {e}")
-
-    def get_model_and_tokenizer(self, model_type):
-        if model_type == "indictrans":
-            if self.indictrans_model is None:
-                self.load_indictrans_model()
-            return self.indictrans_model, self.indictrans_tokenizer
-        else:  # sarvam
-            if self.sarvam_model is None:
-                self.load_sarvam_model()
-            return self.sarvam_model, self.sarvam_tokenizer
+indictrans_model = AutoModelForCausalLM.from_pretrained(
+    "ai4bharat/IndicTrans3-beta",
+    torch_dtype=TORCH_DTYPE,
+    device_map=DEVICE_MAP,
+    token=HF_TOKEN,
+    low_cpu_mem_usage=True,
+    trust_remote_code=True
+)
+
+sarvam_model = AutoModelForCausalLM.from_pretrained(
+    "sarvamai/sarvam-translate",
+    torch_dtype=TORCH_DTYPE,
+    device_map=DEVICE_MAP,
+    token=HF_TOKEN,
+    low_cpu_mem_usage=True,
+    trust_remote_code=True
+)
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "ai4bharat/IndicTrans3-beta",
+    trust_remote_code=True
+)
 
-# Global model manager
-model_manager = ModelManager()
 
 def format_message_for_translation(message, target_lang):
     return f"Translate the following text to {target_lang}: {message}"
@@ -175,7 +132,10 @@ def translate_message(
     model_type: str = "indictrans"
 ) -> Iterator[str]:
 
-    model, tokenizer = model_manager.get_model_and_tokenizer(model_type)
+    if model_type == "indictrans":
+        model = indictrans_model
+    elif model_type == "sarvam":
+        model = sarvam_model
 
     if model is None or tokenizer is None:
         yield "Error: Model failed to load. Please try again."