import torch
from transformers import AutoTokenizer
import gradio as gr
from model import LlamaForCausalLM  # custom Llama implementation defined in model.py

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"

# Initialize the model with the 135M-parameter configuration
model = LlamaForCausalLM(
    vocab_size=tokenizer.vocab_size,
    dim=576,
    num_layers=30,
    hidden_dim=1536,
    num_heads=9
)
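# These dimensions match the published SmolLM2-135M configuration
# (hidden size 576, 30 layers, 9 attention heads, MLP size 1536).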
device = "cpu"
checkpoint_path = "model_bin.pt"

# The checkpoint is expected to store the weights under 'model_state_dict',
# as saved by the training script
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    # Gradio sliders pass floats; the loop bound and top-k count must be ints
    max_length, top_k = int(max_length), int(top_k)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Temperature-scale the logits for the last position
            next_token_logits = outputs[:, -1, :] / temperature

            # Top-k sampling: restrict to the k most likely tokens and renormalize
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)

            # Sample one token from the truncated distribution
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]

            # Stop at end-of-sequence
            if next_token.item() == tokenizer.eos_token_id:
                break

            # Append the sampled token and feed the longer sequence back in
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Prompt", lines=3),
        gr.Slider(50, 200, value=50, step=1, label="Max Length"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(10, 100, value=50, step=1, label="Top-k")
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="SmolLM2 Demo",
    description="A 135M-parameter language model trained on the smollm-corpus"
)

if __name__ == "__main__":
    demo.launch()
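
# Quick smoke test without the UI (a sketch, assuming this file is saved as app.py
# and model_bin.pt sits alongside it):
#   python -c "from app import generate_text; print(generate_text('Once upon a time', 30))"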