import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository, list_repo_files, hf_hub_download
import torch.nn.functional as F
# ===========================
# Device
# ===========================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"π§ Using device: {device} (CPU-only mode)")
# ===========================
# Model Definition
# ===========================
class GemmaTrainer(nn.Module, FlashPackMixin):
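    """Small MLP that maps a pooled short-prompt embedding (1536-d = mean+max
    pooled GPT-2 hidden states, see build_encoder) into the embedding space of
    long prompts. FlashPackMixin supplies the save_flashpack/from_flashpack
    serialization used below."""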
def __init__(self):
super().__init__()
input_dim = 1536
hidden_dim = 1024
output_dim = 1536
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
# ===========================
# Encoder
# ===========================
def build_encoder(model_name="gpt2", max_length: int = 128):
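    """Build a frozen GPT-2 encoder. Each prompt is tokenized to max_length
    and pooled over tokens (mean + max concatenated), yielding a
    2 * 768 = 1536-dim embedding kept on CPU."""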
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
embed_model = AutoModel.from_pretrained(model_name).to(device)
embed_model.eval()
@torch.no_grad()
def encode(prompt: str) -> torch.Tensor:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=max_length).to(device)
last_hidden = embed_model(**inputs).last_hidden_state
mean_pool = last_hidden.mean(dim=1)
max_pool, _ = last_hidden.max(dim=1)
return torch.cat([mean_pool, max_pool], dim=1).cpu() # doubled embedding
return tokenizer, embed_model, encode
# ===========================
# Push model to HF
# ===========================
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
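    """Save the model as a FlashPack file and push it to the Hub via a git
    clone. Note: Repository is deprecated in recent huggingface_hub releases;
    upload_file/upload_folder is the modern replacement."""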
with tempfile.TemporaryDirectory() as tmp_dir:
log_fn(f"π¦ Preparing repository {hf_repo}...")
repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
with open(os.path.join(tmp_dir, "README.md"), "w") as f:
f.write("# FlashPack Model\nTrained locally and pushed to HF.")
log_fn("β³ Pushing model to Hugging Face...")
repo.push_to_hub()
log_fn(f"β
Model pushed to {hf_repo}")
# ===========================
# Training
# ===========================
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
hf_repo="rahul7star/FlashPack",
max_encode=1000):
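    """Encode short/long prompt pairs with GPT-2, train the MLP to align the
    two embedding spaces via cosine similarity, push the result to hf_repo,
    and return (model, tokenizer, embed_model, enhance_fn, logs)."""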
logs = []
def log_fn(msg):
logs.append(msg)
print(msg)
log_fn("π¦ Loading dataset...")
dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
log_fn(f"β
Loaded {len(dataset)} samples")
tokenizer, embed_model, encode_fn = build_encoder("gpt2")
# Encode embeddings
s_list, l_list = [], []
for i, item in enumerate(dataset):
s_list.append(encode_fn(item["short_prompt"]))
l_list.append(encode_fn(item["long_prompt"]))
if (i + 1) % 50 == 0:
log_fn(f" β Encoded {i + 1}/{len(dataset)}")
gc.collect()
short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)
# Save embeddings & prompts for nearest-neighbor retrieval
train_prompts = [item["long_prompt"] for item in dataset]
model = GemmaTrainer()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CosineSimilarity(dim=1)
log_fn("π Training model...")
for epoch in range(20):
model.train()
optimizer.zero_grad()
preds = model(short_emb)
loss = 1 - loss_fn(preds, long_emb).mean()
loss.backward()
optimizer.step()
log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
if loss.item() < 0.01:
log_fn("π― Early stopping.")
break
push_flashpack_model_to_hf(model, hf_repo, log_fn)
@torch.no_grad()
def enhance_fn(prompt, chat):
chat = chat or []
short_emb_input = encode_fn(prompt)
mapped_emb = model(short_emb_input).cpu()
# Nearest neighbor
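        # mapped_emb is (1, 1536); broadcasting against long_emb (N, 1536)
        # gives one cosine score per training sample.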
sims = F.cosine_similarity(mapped_emb, long_emb)
best_idx = sims.argmax().item()
long_prompt = train_prompts[best_idx]
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": f"π Enhanced prompt: {long_prompt}"})
return chat
return model, tokenizer, embed_model, enhance_fn, logs
# ===========================
# Lazy Load / Get Model
# ===========================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
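    """Load a FlashPack model from a local file or from the Hub.
    Returns (None, None, None, None) when no pretrained model exists."""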
local_model_path = "model.flashpack"
if os.path.exists(local_model_path):
print("β
Loading local model")
else:
try:
files = list_repo_files(hf_repo)
if "model.flashpack" in files:
print("β
Downloading model from HF")
local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
else:
print("π« No pretrained model found")
return None, None, None, None
except Exception as e:
print(f"β οΈ Error accessing HF: {e}")
return None, None, None, None
    # from_flashpack is a classmethod; call it on the class rather than on a
    # throwaway instance.
    model = GemmaTrainer.from_flashpack(local_model_path)
model.eval()
tokenizer, embed_model, encode_fn = build_encoder("gpt2")
# Dummy placeholders for nearest neighbor retrieval (replace with actual dataset if available)
long_emb = torch.randn(10, 1536) # placeholder embeddings
train_prompts = [f"Example long prompt {i}" for i in range(10)]
@torch.no_grad()
def enhance_fn(prompt, chat):
chat = chat or []
short_emb_input = encode_fn(prompt)
mapped_emb = model(short_emb_input).cpu()
sims = F.cosine_similarity(mapped_emb, long_emb)
best_idx = sims.argmax().item()
long_prompt = train_prompts[best_idx]
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": f"π Enhanced prompt: {long_prompt}"})
return chat
return model, tokenizer, embed_model, enhance_fn
# ===========================
# Gradio UI
# ===========================
with gr.Blocks(title="β¨ FlashPack Prompt Enhancer") as demo:
gr.Markdown("## π§ FlashPack Prompt Enhancer (CPU)\nShort β Long prompt expander")
chatbot = gr.Chatbot(height=400, type="messages")
user_input = gr.Textbox(label="Your prompt")
send_btn = gr.Button("π Enhance Prompt", variant="primary")
clear_btn = gr.Button("π§Ή Clear")
train_btn = gr.Button("π§© Train Model", variant="secondary")
log_output = gr.Textbox(label="Logs", lines=15)
# Lazy load model
model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
logs = []
if enhance_fn is None:
def enhance_fn(prompt, chat):
chat = chat or []
chat.append({"role": "assistant",
"content": "β οΈ No pretrained model found. Please click 'Train Model' to create one."})
return chat
logs.append("β οΈ No pretrained model found. Ready to train.")
else:
logs.append("β
Model loaded β ready to enhance.")
    # Button callbacks. Dispatch through a wrapper so that retrain()'s
    # rebinding of the module-level enhance_fn is picked up by events that
    # were registered before training.
    def on_enhance(prompt, chat):
        return enhance_fn(prompt, chat)
    send_btn.click(on_enhance, [user_input, chatbot], chatbot)
    user_input.submit(on_enhance, [user_input, chatbot], chatbot)
clear_btn.click(lambda: [], None, chatbot)
def retrain():
global model, tokenizer, embed_model, enhance_fn, logs
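        # Rebinding enhance_fn here takes effect via the on_enhance wrapper.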
logs = ["π Training model, please wait..."]
model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
logs.extend(train_logs)
        # Gradio 4.x: return the new value directly (gr.Textbox.update was removed).
        return "\n".join(logs)
train_btn.click(retrain, None, log_output)
if __name__ == "__main__":
demo.launch(show_error=True)