rahul7star committed (verified)
Commit 286064b · 1 Parent(s): 44b2be8

Update app_flash1.py

Files changed (1): app_flash1.py +26 -11
app_flash1.py CHANGED
@@ -9,7 +9,11 @@ from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModel
 from flashpack import FlashPackMixin
 from huggingface_hub import Repository, list_repo_files, hf_hub_download
+import torch.nn.functional as F
 
+# ===========================
+# Device
+# ===========================
 device = torch.device("cpu")
 torch.set_num_threads(4)
 print(f"🔧 Using device: {device} (CPU-only mode)")
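The only functional change in this hunk besides the banner comment is the new import; `F.cosine_similarity` is what the retrieval code added below relies on. It broadcasts a (1, D) query row against an (N, D) matrix and reduces over dim=1 by default, returning N scores. A quick self-contained check of that behavior (shapes illustrative, not taken from the repo):

```python
import torch
import torch.nn.functional as F

query = torch.randn(1, 1536)    # one pooled prompt embedding
corpus = torch.randn(10, 1536)  # ten stored embeddings
sims = F.cosine_similarity(query, corpus)  # broadcasts, reduces over dim=1
print(sims.shape)  # torch.Size([10])
```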
@@ -20,7 +24,7 @@ print(f"🔧 Using device: {device} (CPU-only mode)")
 class GemmaTrainer(nn.Module, FlashPackMixin):
     def __init__(self):
         super().__init__()
-        input_dim = 1536  # GPT-2 mean+max pooled embeddings
+        input_dim = 1536
         hidden_dim = 1024
         output_dim = 1536
         self.fc1 = nn.Linear(input_dim, hidden_dim)
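The deleted comment recorded why input_dim is 1536: GPT-2's hidden size is 768, and concatenating mean pooling with max pooling over the token dimension gives 2 × 768 = 1536. `build_encoder` itself is outside this diff, so the `encode_fn` below is only a sketch of that assumed pooling scheme, not the repo's actual implementation:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
embed_model = AutoModel.from_pretrained("gpt2")

@torch.no_grad()
def encode_fn(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    hidden = embed_model(**inputs).last_hidden_state  # (1, seq, 768)
    mean_pool = hidden.mean(dim=1)                    # (1, 768)
    max_pool = hidden.max(dim=1).values               # (1, 768)
    return torch.cat([mean_pool, max_pool], dim=1)    # (1, 1536)
```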
@@ -90,7 +94,7 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
 
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
-    # Encode dataset embeddings
+    # Encode embeddings
    s_list, l_list = [], []
     for i, item in enumerate(dataset):
         s_list.append(encode_fn(item["short_prompt"]))
@@ -100,6 +104,9 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
         gc.collect()
     short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)
 
+    # Save embeddings & prompts for nearest-neighbor retrieval
+    train_prompts = [item["long_prompt"] for item in dataset]
+
     model = GemmaTrainer()
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
     loss_fn = nn.CosineSimilarity(dim=1)
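A note on the loss: `nn.CosineSimilarity(dim=1)` yields row-wise similarities in [-1, 1], not a scalar loss, so the training loop (outside this hunk) presumably minimizes something like `(1 - sim).mean()`. A minimal sketch of one such step, with a bare `nn.Linear` standing in for GemmaTrainer:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1536, 1536)      # stand-in for GemmaTrainer
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CosineSimilarity(dim=1)

short_emb = torch.randn(32, 1536)  # fake batch of short-prompt embeddings
long_emb = torch.randn(32, 1536)   # matching long-prompt embeddings

optimizer.zero_grad()
loss = (1 - loss_fn(model(short_emb), long_emb)).mean()  # push mapped short toward long
loss.backward()
optimizer.step()
```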
@@ -118,16 +125,18 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
             break
 
     push_flashpack_model_to_hf(model, hf_repo, log_fn)
-    tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
     @torch.no_grad()
     def enhance_fn(prompt, chat):
         chat = chat or []
-        short_emb = encode_fn(prompt)
-        mapped = model(short_emb.to(device)).cpu()
-        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
+        short_emb_input = encode_fn(prompt)
+        mapped_emb = model(short_emb_input).cpu()
+        # Nearest neighbor
+        sims = F.cosine_similarity(mapped_emb, long_emb)
+        best_idx = sims.argmax().item()
+        long_prompt = train_prompts[best_idx]
         chat.append({"role": "user", "content": prompt})
-        chat.append({"role": "assistant", "content": long_prompt})
+        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
         return chat
 
     return model, tokenizer, embed_model, enhance_fn, logs
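Taken together, the rewritten `enhance_fn` replaces the old canned string with genuine retrieval: encode the prompt, map it through the trained model, score it against every stored long-prompt embedding, and return the best match. It also closes over `model`, `long_emb`, and `train_prompts` from the enclosing scope, so callers only pass the prompt and chat history. A runnable mock of the whole path, with stand-ins (Identity model, random encoder) for everything the diff doesn't show:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

long_emb = torch.randn(100, 1536)              # stand-in training embeddings
train_prompts = [f"long prompt {i}" for i in range(100)]
model = nn.Identity()                          # stand-in for the trained GemmaTrainer
encode_fn = lambda text: torch.randn(1, 1536)  # stand-in encoder

@torch.no_grad()
def enhance_fn(prompt, chat):
    chat = chat or []
    mapped_emb = model(encode_fn(prompt)).cpu()
    sims = F.cosine_similarity(mapped_emb, long_emb)   # (100,) scores
    long_prompt = train_prompts[sims.argmax().item()]  # nearest neighbor
    chat.append({"role": "user", "content": prompt})
    chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
    return chat

print(enhance_fn("a cat on a hill", [])[-1]["content"])
```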
@@ -157,14 +166,20 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     model.eval()
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
+    # Dummy placeholders for nearest-neighbor retrieval (replace with actual dataset if available)
+    long_emb = torch.randn(10, 1536)  # placeholder embeddings
+    train_prompts = [f"Example long prompt {i}" for i in range(10)]
+
     @torch.no_grad()
     def enhance_fn(prompt, chat):
         chat = chat or []
-        short_emb = encode_fn(prompt).to(device)
-        mapped = model(short_emb).cpu()
-        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
+        short_emb_input = encode_fn(prompt)
+        mapped_emb = model(short_emb_input).cpu()
+        sims = F.cosine_similarity(mapped_emb, long_emb)
+        best_idx = sims.argmax().item()
+        long_prompt = train_prompts[best_idx]
         chat.append({"role": "user", "content": prompt})
-        chat.append({"role": "assistant", "content": long_prompt})
+        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
         return chat
 
     return model, tokenizer, embed_model, enhance_fn
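One caveat worth flagging on the last hunk: `get_flashpack_model` substitutes random placeholder embeddings and synthetic prompts, so inference-time retrieval returns arbitrary text until real artifacts are wired in. A hedged sketch of persisting the training-time tensors so inference can reuse them; the file names here are hypothetical, not from the repo:

```python
import json
import torch

# Stand-ins for the tensors built in train_flashpack_model
long_emb = torch.randn(100, 1536)
train_prompts = [f"long prompt {i}" for i in range(100)]

# Persist alongside the FlashPack checkpoint (hypothetical file names)
torch.save(long_emb, "long_emb.pt")
with open("train_prompts.json", "w") as f:
    json.dump(train_prompts, f)

# In get_flashpack_model, instead of the torch.randn placeholder:
long_emb = torch.load("long_emb.pt")
with open("train_prompts.json") as f:
    train_prompts = json.load(f)
```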