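"""FlashPack Prompt Enhancer (CPU-only).

Trains a small MLP that maps GPT-2 embeddings of short prompts into the
embedding space of long prompts, serializes the weights with FlashPack,
pushes them to the Hugging Face Hub, and serves a Gradio chat UI that
"enhances" a short prompt by retrieving the nearest long prompt from the
training set.
"""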
import os
import gc
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import list_repo_files, hf_hub_download, upload_folder

# CPU-only mode: keep everything on the CPU so the app runs without a GPU.
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only mode)")

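# Prompt-mapping model: a plain three-layer MLP (1536 -> 1024 -> 1024 -> 1536).
# FlashPackMixin contributes the save_flashpack/from_flashpack serialization
# hooks. Despite the name, no Gemma weights are involved; the module only maps
# between GPT-2 embedding spaces.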
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self):
        super().__init__()
        input_dim = 1536   # mean+max pooled GPT-2 embedding (2 x 768)
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

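# Frozen GPT-2 sentence encoder. Each prompt is tokenized to a fixed length;
# the last hidden states are mean-pooled and max-pooled, and the two 768-dim
# vectors are concatenated into a single 1536-dim embedding.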
def build_encoder(model_name="gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           padding="max_length", max_length=max_length).to(device)
        last_hidden = embed_model(**inputs).last_hidden_state
        mean_pool = last_hidden.mean(dim=1)
        max_pool, _ = last_hidden.max(dim=1)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()

    return tokenizer, embed_model, encode

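# Serialize the trained model with FlashPack and upload it to the Hub.
# Note: this sketch assumes a recent huggingface_hub, where upload_folder()
# replaces the deprecated Repository/push_to_hub workflow used originally;
# authentication comes from the cached token (huggingface-cli login / HF_TOKEN).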
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"📦 Preparing repository {hf_repo}...")
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        with open(os.path.join(tmp_dir, "README.md"), "w") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        upload_folder(repo_id=hf_repo, folder_path=tmp_dir)
        log_fn(f"✅ Model pushed to {hf_repo}")

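# End-to-end training: encode short/long prompt pairs with the frozen GPT-2
# encoder, fit the MLP with a cosine-distance loss, push the weights to the
# Hub, and return a closure that enhances prompts by nearest-neighbour
# retrieval over the training set's long prompts.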
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
                          hf_repo="rahul7star/FlashPack",
                          max_encode=1000):
    logs = []

    def log_fn(msg):
        logs.append(msg)
        print(msg)

    log_fn("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
    log_fn(f"✅ Loaded {len(dataset)} samples")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Encode every short/long prompt pair into fixed-size embeddings
    s_list, l_list = [], []
    for i, item in enumerate(dataset):
        s_list.append(encode_fn(item["short_prompt"]))
        l_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0:
            log_fn(f"  → Encoded {i + 1}/{len(dataset)}")
            gc.collect()
    short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)

    # Keep the raw long prompts so inference can return the nearest match as text
    train_prompts = [item["long_prompt"] for item in dataset]

    model = GemmaTrainer()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    cos_sim = nn.CosineSimilarity(dim=1)

    log_fn("🚀 Training model...")
    for epoch in range(20):
        model.train()
        optimizer.zero_grad()
        preds = model(short_emb)
        # Cosine-distance loss: pull mapped short embeddings toward the long ones
        loss = 1 - cos_sim(preds, long_emb).mean()
        loss.backward()
        optimizer.step()
        log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
        if loss.item() < 0.01:
            log_fn("🎯 Early stopping.")
            break

    push_flashpack_model_to_hf(model, hf_repo, log_fn)

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb_input = encode_fn(prompt)
        mapped_emb = model(short_emb_input).cpu()
        # Retrieve the training long prompt whose embedding is most similar
        sims = F.cosine_similarity(mapped_emb, long_emb)
        best_idx = sims.argmax().item()
        long_prompt = train_prompts[best_idx]
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
        return chat

    return model, tokenizer, embed_model, enhance_fn, logs

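# Inference-time loader: prefer a local model.flashpack, otherwise download it
# from the Hub. Returns (None, None, None, None) when no pretrained model can
# be found, which the UI below treats as "needs training".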
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    local_model_path = "model.flashpack"

    if os.path.exists(local_model_path):
        print("✅ Loading local model")
    else:
        try:
            files = list_repo_files(hf_repo)
            if "model.flashpack" in files:
                print("✅ Downloading model from HF")
                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
            else:
                print("🚫 No pretrained model found")
                return None, None, None, None
        except Exception as e:
            print(f"⚠️ Error accessing HF: {e}")
            return None, None, None, None

    model = GemmaTrainer().from_flashpack(local_model_path)
    model.eval()
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Placeholder retrieval corpus: random embeddings and dummy prompts.
    # Real reference data only exists right after training; a model loaded
    # here would need the dataset re-encoded to return meaningful matches.
    long_emb = torch.randn(10, 1536)
    train_prompts = [f"Example long prompt {i}" for i in range(10)]

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb_input = encode_fn(prompt)
        mapped_emb = model(short_emb_input).cpu()
        sims = F.cosine_similarity(mapped_emb, long_emb)
        best_idx = sims.argmax().item()
        long_prompt = train_prompts[best_idx]
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
        return chat

    return model, tokenizer, embed_model, enhance_fn

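# Gradio UI: chat-style enhancer with buttons to clear the chat, retrain the
# model, and inspect training logs.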
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")

    chatbot = gr.Chatbot(height=400, type="messages")
    user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
    clear_btn = gr.Button("🧹 Clear")
    train_btn = gr.Button("🧩 Train Model", variant="secondary")
    log_output = gr.Textbox(label="Logs", lines=15)

    model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
    logs = []

    if enhance_fn is None:
        def enhance_fn(prompt, chat):
            chat = chat or []
            chat.append({"role": "assistant",
                         "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
            return chat
        logs.append("⚠️ No pretrained model found. Ready to train.")
    else:
        logs.append("✅ Model loaded. Ready to enhance.")

    # Dispatch through lambdas so retraining can swap enhance_fn in place;
    # binding enhance_fn directly would freeze the pre-training version.
    send_btn.click(lambda prompt, chat: enhance_fn(prompt, chat), [user_input, chatbot], chatbot)
    user_input.submit(lambda prompt, chat: enhance_fn(prompt, chat), [user_input, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    def retrain():
        global model, tokenizer, embed_model, enhance_fn, logs
        logs = ["🚀 Training model, please wait..."]
        model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
        logs.extend(train_logs)
        # Gradio 4+ removed gr.Textbox.update; return the new value directly
        return "\n".join(logs)

    train_btn.click(retrain, None, log_output)

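# To run locally (assumes torch, transformers, datasets, gradio, flashpack and
# huggingface_hub are installed; pushing to the Hub also needs a logged-in HF
# token), save this script and launch it directly, e.g. `python app.py` if the
# file is named app.py.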
if __name__ == "__main__":
    demo.launch(show_error=True)