import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import create_repo, upload_folder, list_repo_files, hf_hub_download
import torch.nn.functional as F
# ===========================
# Device
# ===========================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"πŸ”§ Using device: {device} (CPU-only mode)")
# ===========================
# Model Definition
# ===========================
class GemmaTrainer(nn.Module, FlashPackMixin):
def __init__(self):
super().__init__()
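        # 1536-dim in/out: the encoder below concatenates mean- and max-pooled
        # GPT-2 hidden states (2 x 768), so the mapper works in that doubled space.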
input_dim = 1536
hidden_dim = 1024
output_dim = 1536
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
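# A minimal sanity check for the mapper (hypothetical, not part of the app flow):
#   mapper = GemmaTrainer()
#   out = mapper(torch.randn(4, 1536))
#   assert out.shape == (4, 1536)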
# ===========================
# Encoder
# ===========================
def build_encoder(model_name="gpt2", max_length: int = 128):
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
embed_model = AutoModel.from_pretrained(model_name).to(device)
embed_model.eval()
@torch.no_grad()
def encode(prompt: str) -> torch.Tensor:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=max_length).to(device)
last_hidden = embed_model(**inputs).last_hidden_state
mean_pool = last_hidden.mean(dim=1)
max_pool, _ = last_hidden.max(dim=1)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # (1, 1536): mean and max pools concatenated
return tokenizer, embed_model, encode
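# Example usage (assumes GPT-2's hidden size of 768, so mean+max pooling yields 1536 dims):
#   tokenizer, embed_model, encode = build_encoder("gpt2")
#   vec = encode("a cat sitting on a mat")  # vec.shape == torch.Size([1, 1536])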
# ===========================
# Push model to HF
# ===========================
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    # Assumes a Hugging Face token is available (HF_TOKEN env var or `huggingface-cli login`).
    # Uses create_repo/upload_folder; the older Repository class is deprecated in huggingface_hub.
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"📦 Preparing repository {hf_repo}...")
        create_repo(hf_repo, exist_ok=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        with open(os.path.join(tmp_dir, "README.md"), "w") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        upload_folder(folder_path=tmp_dir, repo_id=hf_repo)
        log_fn(f"✅ Model pushed to {hf_repo}")
# ===========================
# Training
# ===========================
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
hf_repo="rahul7star/FlashPack",
max_encode=1000):
logs = []
def log_fn(msg):
logs.append(msg)
print(msg)
log_fn("πŸ“¦ Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_encode, len(dataset))))  # avoid IndexError on small datasets
log_fn(f"βœ… Loaded {len(dataset)} samples")
tokenizer, embed_model, encode_fn = build_encoder("gpt2")
# Encode embeddings
s_list, l_list = [], []
for i, item in enumerate(dataset):
s_list.append(encode_fn(item["short_prompt"]))
l_list.append(encode_fn(item["long_prompt"]))
if (i + 1) % 50 == 0:
log_fn(f" β†’ Encoded {i + 1}/{len(dataset)}")
gc.collect()
short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)
    # Keep long-prompt embeddings and texts in memory for nearest-neighbor retrieval
train_prompts = [item["long_prompt"] for item in dataset]
model = GemmaTrainer()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CosineSimilarity(dim=1)
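    # The objective below is 1 - mean cosine similarity: minimizing it pushes each
    # predicted embedding to point in the same direction as its long-prompt target.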
log_fn("πŸš€ Training model...")
for epoch in range(20):
model.train()
optimizer.zero_grad()
preds = model(short_emb)
loss = 1 - loss_fn(preds, long_emb).mean()
loss.backward()
optimizer.step()
log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
if loss.item() < 0.01:
log_fn("🎯 Early stopping.")
break
push_flashpack_model_to_hf(model, hf_repo, log_fn)
@torch.no_grad()
def enhance_fn(prompt, chat):
chat = chat or []
short_emb_input = encode_fn(prompt)
mapped_emb = model(short_emb_input).cpu()
        # Nearest-neighbor retrieval: return the stored long prompt whose embedding is closest
sims = F.cosine_similarity(mapped_emb, long_emb)
best_idx = sims.argmax().item()
long_prompt = train_prompts[best_idx]
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
return chat
return model, tokenizer, embed_model, enhance_fn, logs
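# Example of a local run (hypothetical; downloads the dataset and gpt2 weights):
#   model, tok, emb, enhance, logs = train_flashpack_model(max_encode=100)
#   chat = enhance("a cat", [])  # messages-format history ending with the enhanced prompt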
# ===========================
# Lazy Load / Get Model
# ===========================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
local_model_path = "model.flashpack"
if os.path.exists(local_model_path):
print("βœ… Loading local model")
else:
try:
files = list_repo_files(hf_repo)
if "model.flashpack" in files:
print("βœ… Downloading model from HF")
local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
else:
print("🚫 No pretrained model found")
return None, None, None, None
except Exception as e:
print(f"⚠️ Error accessing HF: {e}")
return None, None, None, None
    model = GemmaTrainer.from_flashpack(local_model_path)  # classmethod: no need to instantiate first
model.eval()
tokenizer, embed_model, encode_fn = build_encoder("gpt2")
    # Dummy placeholders for nearest-neighbor retrieval; with random embeddings the
    # retrieved prompt is arbitrary. Replace with embeddings saved at training time.
long_emb = torch.randn(10, 1536) # placeholder embeddings
train_prompts = [f"Example long prompt {i}" for i in range(10)]
@torch.no_grad()
def enhance_fn(prompt, chat):
chat = chat or []
short_emb_input = encode_fn(prompt)
mapped_emb = model(short_emb_input).cpu()
sims = F.cosine_similarity(mapped_emb, long_emb)
best_idx = sims.argmax().item()
long_prompt = train_prompts[best_idx]
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
return chat
return model, tokenizer, embed_model, enhance_fn
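# Example (hypothetical): lazily load the pushed model and enhance one prompt.
#   model, tok, emb, enhance = get_flashpack_model("rahul7star/FlashPack")
#   if enhance is not None:
#       chat = enhance("a cat", [])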
# ===========================
# Gradio UI
# ===========================
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort β†’ Long prompt expander")
chatbot = gr.Chatbot(height=400, type="messages")
user_input = gr.Textbox(label="Your prompt")
send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
clear_btn = gr.Button("🧹 Clear")
train_btn = gr.Button("🧩 Train Model", variant="secondary")
log_output = gr.Textbox(label="Logs", lines=15)
# Lazy load model
model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
logs = []
if enhance_fn is None:
def enhance_fn(prompt, chat):
chat = chat or []
chat.append({"role": "assistant",
"content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
return chat
logs.append("⚠️ No pretrained model found. Ready to train.")
else:
logs.append("βœ… Model loaded β€” ready to enhance.")
    # Button callbacks. The dispatcher resolves enhance_fn at call time, so retraining
    # (which rebinds the global below) takes effect without re-wiring the events.
    def enhance_dispatch(prompt, chat):
        return enhance_fn(prompt, chat)
    send_btn.click(enhance_dispatch, [user_input, chatbot], chatbot)
    user_input.submit(enhance_dispatch, [user_input, chatbot], chatbot)
clear_btn.click(lambda: [], None, chatbot)
def retrain():
global model, tokenizer, embed_model, enhance_fn, logs
logs = ["πŸš€ Training model, please wait..."]
model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
logs.extend(train_logs)
        return "\n".join(logs)  # Gradio 4+ removed Component.update(); return the value directly
train_btn.click(retrain, None, log_output)
if __name__ == "__main__":
demo.launch(show_error=True)