rahul7star committed
Commit 503e5c1 · verified · 1 Parent(s): 2a7aa21

Update app_flash.py

Files changed (1):
  1. app_flash.py +3 -54
app_flash.py CHANGED
@@ -9,7 +9,6 @@ from datasets import load_dataset
  from transformers import AutoTokenizer, AutoModel
  from flashpack import FlashPackMixin
  from huggingface_hub import Repository
- from huggingface_hub import Repository, list_repo_files, hf_hub_download
  from typing import Tuple
 
  # ============================================================
@@ -23,17 +22,14 @@ print(f"🔧 Using device: {device} (CPU-only)")
  # 1️⃣ FlashPack model with better hidden layers
  # ============================================================
  class GemmaTrainer(nn.Module, FlashPackMixin):
-     def __init__(self):
+     def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
          super().__init__()
-         input_dim = 1536
-         hidden_dim = 1024
-         output_dim = 1536
          self.fc1 = nn.Linear(input_dim, hidden_dim)
          self.relu = nn.ReLU()
          self.fc2 = nn.Linear(hidden_dim, hidden_dim)
          self.fc3 = nn.Linear(hidden_dim, output_dim)
 
-     def forward(self, x: torch.Tensor):
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
          x = self.fc1(x)
          x = self.relu(x)
          x = self.fc2(x)
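The constructor change above replaces hard-coded layer sizes with parameters (defaults: hidden_dim=1024, output_dim=1536). A minimal usage sketch, assuming GemmaTrainer is importable from app_flash.py and that the truncated forward() continues through fc3 as __init__ suggests:

    import torch
    from app_flash import GemmaTrainer  # hypothetical import path, for illustration only

    model = GemmaTrainer(input_dim=1536)  # hidden_dim/output_dim fall back to the new defaults
    x = torch.randn(1, 1536)              # dummy short-prompt embedding
    y = model(x)                          # fc1 -> relu -> fc2 -> fc3
    print(y.shape)                        # expected: torch.Size([1, 1536])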
@@ -157,53 +153,6 @@ def train_flashpack_model(
  # 5️⃣ Load or train model
  # ============================================================
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
-     input_dim = 1536  # must match the input_dim used during training
-     try:
-         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
-
-         # 1️⃣ Try local model first
-         local_model_path = "model.flashpack"
-         if os.path.exists(local_model_path):
-             print("✅ Loading local model")
-         else:
-             # 2️⃣ Try Hugging Face
-             files = list_repo_files(hf_repo)
-             if "model.flashpack" in files:
-                 print("✅ Downloading model from HF")
-                 from huggingface_hub import hf_hub_download
-                 local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
-             else:
-                 print("🚫 No pretrained model found")
-                 return None, None, None, None
-
-         # 3️⃣ Load model with correct input_dim
-         model = GemmaTrainer(input_dim=input_dim).from_flashpack(local_model_path)
-         model.eval()
-
-         # 4️⃣ Build encoder
-         tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
-
-         # 5️⃣ Enhancement function
-         @torch.no_grad()
-         def enhance_fn(prompt, chat):
-             chat = chat or []
-             short_emb = encode_fn(prompt).to(device)
-             mapped = model(short_emb).cpu()
-             long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
-             chat.append({"role": "user", "content": prompt})
-             chat.append({"role": "assistant", "content": long_prompt})
-             return chat
-
-         return model, tokenizer, embed_model, enhance_fn
-
-     except Exception as e:
-         print(f"⚠️ Load failed: {e}")
-         print("⏬ Training a new FlashPack model locally...")
-         model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
-         push_flashpack_model_to_hf(model, hf_repo, log_fn=print)
-         return model, tokenizer, embed_model, None
-
- def get_flashpack_model1(hf_repo="rahul7star/FlashPack"):
      try:
          print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
          model = GemmaTrainer.from_flashpack(hf_repo)
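With the local/Hub fallback branch above deleted, loading is delegated entirely to FlashPackMixin. A minimal sketch of the surviving path, assuming from_flashpack accepts a Hub repo id as the retained line implies:

    from app_flash import GemmaTrainer  # hypothetical import path, for illustration only

    # No list_repo_files / hf_hub_download plumbing anymore:
    # the mixin resolves and loads the packed weights in one call.
    model = GemmaTrainer.from_flashpack("rahul7star/FlashPack")
    model.eval()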
@@ -280,4 +229,4 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
  # 9️⃣ Launch
  # ============================================================
  if __name__ == "__main__":
-     demo.launch(show_error=True)
+     demo.launch(show_error=True)
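For reference, a self-contained sketch of the launch pattern kept by this commit; show_error=True makes Gradio surface Python tracebacks in the UI. The Blocks body here is a placeholder, not the app's real layout:

    import gradio as gr

    with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
        gr.Markdown("placeholder UI")

    if __name__ == "__main__":
        demo.launch(show_error=True)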
 