import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository, list_repo_files, hf_hub_download
import torch.nn.functional as F

# ===========================
# Device
# ===========================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"πŸ”§ Using device: {device} (CPU-only mode)")

# ===========================
# Model Definition
# ===========================
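# A small three-layer MLP that maps the 1536-d pooled GPT-2 embedding of a
# short prompt toward the embedding of its long counterpart. 1536 = mean pool
# + max pool over GPT-2's 768-d hidden states, concatenated (see build_encoder).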
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self):
        super().__init__()
        input_dim = 1536
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# ===========================
# Encoder
# ===========================
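# GPT-2 is used purely as a frozen feature extractor; only the MLP above is trained.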
def build_encoder(model_name="gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           padding="max_length", max_length=max_length).to(device)
        last_hidden = embed_model(**inputs).last_hidden_state  # (1, seq_len, 768)
        mean_pool = last_hidden.mean(dim=1)                    # (1, 768)
        max_pool, _ = last_hidden.max(dim=1)                   # (1, 768)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()   # (1, 1536) mean+max pooled

    return tokenizer, embed_model, encode
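
# Shape sketch (assuming the default "gpt2" checkpoint, hidden size 768):
#   _, _, enc = build_encoder("gpt2")
#   enc("a cat photo").shape  -> torch.Size([1, 1536])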

# ===========================
# Push model to HF
# ===========================
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"πŸ“¦ Preparing repository {hf_repo}...")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        with open(os.path.join(tmp_dir, "README.md"), "w") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        repo.push_to_hub()
        log_fn(f"βœ… Model pushed to {hf_repo}")

# ===========================
# Training
# ===========================
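# Pipeline: encode prompt pairs -> train the MLP with a cosine objective ->
# push the weights to the Hub -> return a retrieval-based enhancer closure.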
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
                          hf_repo="rahul7star/FlashPack",
                          max_encode=1000):
    logs = []

    def log_fn(msg):
        logs.append(msg)
        print(msg)

    log_fn("πŸ“¦ Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_encode, len(dataset))))  # guard against short datasets
    log_fn(f"βœ… Loaded {len(dataset)} samples")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Encode each (short, long) prompt pair into fixed-size embeddings
    s_list, l_list = [], []
    for i, item in enumerate(dataset):
        s_list.append(encode_fn(item["short_prompt"]))
        l_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0:
            log_fn(f"  β†’ Encoded {i + 1}/{len(dataset)}")
            gc.collect()
    short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)

    # Keep the long prompts so nearest-neighbor retrieval can map an embedding back to text
    train_prompts = [item["long_prompt"] for item in dataset]

    model = GemmaTrainer()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Cosine-distance objective: minimize 1 - cos(pred, target)
    similarity = nn.CosineSimilarity(dim=1)

    log_fn("πŸš€ Training model...")
    for epoch in range(20):
        model.train()
        optimizer.zero_grad()
        preds = model(short_emb)
        loss = 1 - similarity(preds, long_emb).mean()
        loss.backward()
        optimizer.step()
        log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
        if loss.item() < 0.01:
            log_fn("🎯 Early stopping.")
            break

    push_flashpack_model_to_hf(model, hf_repo, log_fn)

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb_input = encode_fn(prompt)
        mapped_emb = model(short_emb_input).cpu()
        # Nearest-neighbor retrieval: return the training long-prompt whose
        # embedding is most cosine-similar to the mapped input embedding.
        sims = F.cosine_similarity(mapped_emb, long_emb)  # (N,) via broadcasting
        best_idx = sims.argmax().item()
        long_prompt = train_prompts[best_idx]
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
        return chat

    return model, tokenizer, embed_model, enhance_fn, logs

# ===========================
# Lazy Load / Get Model
# ===========================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    local_model_path = "model.flashpack"

    if os.path.exists(local_model_path):
        print("βœ… Loading local model")
    else:
        try:
            files = list_repo_files(hf_repo)
            if "model.flashpack" in files:
                print("βœ… Downloading model from HF")
                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
            else:
                print("🚫 No pretrained model found")
                return None, None, None, None
        except Exception as e:
            print(f"⚠️ Error accessing HF: {e}")
            return None, None, None, None

    # from_flashpack is a classmethod on FlashPackMixin, so call it on the class
    model = GemmaTrainer.from_flashpack(local_model_path)
    model.eval()
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Placeholder retrieval corpus: random embeddings and synthetic prompts.
    # For meaningful output, persist the training embeddings/prompts alongside
    # the model and load them here instead.
    long_emb = torch.randn(10, 1536)
    train_prompts = [f"Example long prompt {i}" for i in range(10)]

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb_input = encode_fn(prompt)
        mapped_emb = model(short_emb_input).cpu()
        sims = F.cosine_similarity(mapped_emb, long_emb)
        best_idx = sims.argmax().item()
        long_prompt = train_prompts[best_idx]
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
        return chat

    return model, tokenizer, embed_model, enhance_fn

# ===========================
# Gradio UI
# ===========================
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort β†’ Long prompt expander")

    chatbot = gr.Chatbot(height=400, type="messages")
    user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
    clear_btn = gr.Button("🧹 Clear")
    train_btn = gr.Button("🧩 Train Model", variant="secondary")
    log_output = gr.Textbox(label="Logs", lines=15)

    # Lazy load model
    model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
    logs = []

    if enhance_fn is None:
        def enhance_fn(prompt, chat):
            chat = chat or []
            chat.append({"role": "assistant",
                         "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
            return chat
        logs.append("⚠️ No pretrained model found. Ready to train.")
    else:
        logs.append("βœ… Model loaded β€” ready to enhance.")

    # Register a dispatcher rather than enhance_fn itself: Gradio captures the
    # function object at .click() time, so rebinding the global after retraining
    # would otherwise leave the buttons wired to the stale enhancer.
    def enhance_dispatch(prompt, chat):
        return enhance_fn(prompt, chat)

    send_btn.click(enhance_dispatch, [user_input, chatbot], chatbot)
    user_input.submit(enhance_dispatch, [user_input, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    def retrain():
        global model, tokenizer, embed_model, enhance_fn, logs
        logs = ["πŸš€ Training model, please wait..."]
        model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
        logs.extend(train_logs)
        # Returning the raw string updates the Textbox (Textbox.update was removed in Gradio 4)
        return "\n".join(logs)

    train_btn.click(retrain, None, log_output)

if __name__ == "__main__":
    demo.launch(show_error=True)