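"""Inference script for a DPO-finetuned Llama model: loads a saved checkpoint and
generates text from a command-line prompt via top-k sampling with a frequency penalty."""
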
from config import ModelArgs
from model import Llama
import torch
import torch.nn.functional as F
from tokenizer import Tokenizer
import argparse


tokenizer = Tokenizer()
tokenizer = tokenizer.ready_tokenizer()


def remove_hashtag_lines(text):
    """Removes lines that contain hashtags from the given text."""
    lines = text.split("\n")
    cleaned_lines = [line for line in lines if "#" not in line]
    return "\n".join(cleaned_lines)


def remove_prefix(state_dict, prefix):
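    """Strips `prefix` from matching keys of a state dict (e.g. the '_orig_mod.'
    prefix that torch.compile prepends to parameter names)."""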
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]  # Remove the prefix
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict


def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0, frequency_penalty=0.5):
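    """Autoregressively samples up to max_length tokens from `prompt` using
    temperature-scaled top-k sampling with a soft frequency penalty, stopping
    early if the EOS token is drawn. Returns the decoded prompt plus continuation."""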
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # generated_tokens = []  # Store generated tokens
    token_frequencies = {}  # Track token counts

    for step in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs[:, -1, :]  # Get logits for next token
            
            logits = logits / temperature
            # Apply the frequency penalty only after the first token is generated
            if step > 0:
                # Recount occurrences over the current context each step so the
                # penalty tracks actual token frequencies instead of letting the
                # counts compound across iterations
                token_frequencies = {}
                for token in input_ids[0].tolist():
                    token_frequencies[token] = token_frequencies.get(token, 0) + 1

                # Subtract a soft, sub-linear penalty from the logits of tokens already seen
                for token, freq in token_frequencies.items():
                    logits[0, token] -= frequency_penalty * (freq ** 0.8)

            # Convert logits to probabilities
            probs = F.softmax(logits, dim=-1)

            # Top-k filtering
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)

            # Sample from top-k
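            # torch.multinomial treats its input as (possibly unnormalized) weights,
            # so the top-k probabilities can be passed directly without renormalizing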
            next_token = torch.multinomial(top_k_probs, num_samples=1)

            # if next_token == tokenizer.eos_token_id:
            #     break  # Stop if EOS token is generated

            # Store generated token AFTER sampling
            # token_id = next_token.item()
            # generated_tokens.append(token_id)

            # Map the sampled top-k index back to its vocabulary id
            xcol = torch.gather(top_k_indices, -1, next_token)

            # Stop if the EOS token was generated
            if xcol.item() == tokenizer.eos_token_id:
                break

            # Append the new token and continue decoding
            input_ids = torch.cat([input_ids, xcol], dim=1)

    # Decode the prompt together with the generated continuation
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


def main():
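    """Parses CLI arguments, loads the DPO checkpoint, and prints the generated text."""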

    # torch.set_float32_matmul_precision('high')

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=''' Follow the given instructions carefully. My mom is about to retire from her 10 long years of service to a company. write me a message saying how grateful we are for her service to our company. ''')
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.8)
    # parser.add_argument("--repetition_penalty", type=float, default=1.2)
    args = parser.parse_args()
    
    model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
    # model = torch.compile(model)
    model = model.to(ModelArgs.device)

    dict_model = torch.load('DPO_model_1650.pt', map_location=ModelArgs.device)
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])
    model.eval()
    print("Model ready")
    # prompt = 'Its a secret'

    with torch.no_grad():
        generated_text = topk_sampling(model, args.prompt, max_length=args.max_length, top_k=args.top_k, temperature=args.temperature, device=ModelArgs.device)
        # generated_text = remove_hashtag_lines(generated_text)
        print("Generated: ", generated_text)
        # generated_text = beam_search(model, tokenizer, args.prompt, beam_width=5, max_length=50, temperature=1.0)
        # print(args.prompt + generated_text)


if __name__ == '__main__':
    main()
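

# Example invocation (assuming this script is saved as inference.py and the
# DPO_model_1650.pt checkpoint is in the working directory):
#   python inference.py --prompt "Write a short thank-you note." --max_length 128 --top_k 50 --temperature 0.8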