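"""Inference script for a DPO fine-tuned Llama checkpoint.

Loads the checkpoint, strips the `_orig_mod.` prefix that torch.compile adds to
state-dict keys, and generates text from a command-line prompt using
temperature-scaled top-k sampling with a soft frequency penalty.
"""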
import argparse

import torch
import torch.nn.functional as F

from config import ModelArgs
from model import Llama
from tokenizer import Tokenizer

tokenizer = Tokenizer()
tokenizer = tokenizer.ready_tokenizer()

def remove_hashtag_lines(text):
    """Removes lines that contain hashtags from the given text."""
    lines = text.split("\n")
    cleaned_lines = [line for line in lines if "#" not in line]
    return "\n".join(cleaned_lines)

def remove_prefix(state_dict, prefix):
    """Strips `prefix` from state-dict keys, e.g. the `_orig_mod.` prefix that
    torch.compile adds when a compiled model's state dict is saved."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]  # Remove the prefix
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict

def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0, frequency_penalty=0.5):
    """Generates text autoregressively with temperature-scaled top-k sampling
    and a soft frequency penalty on tokens already present in the sequence."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    for step in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs[:, -1, :]  # Logits for the next token
        logits = logits / temperature

        # Apply the frequency penalty only after the first token has been generated
        if step > 0:
            # Count how often each token id occurs in the current sequence
            token_frequencies = {}
            for token in input_ids[0].tolist():
                token_frequencies[token] = token_frequencies.get(token, 0) + 1
            # Soft penalty: subtract a sublinear function of the frequency from the logits
            for token, freq in token_frequencies.items():
                logits[0, token] -= frequency_penalty * (freq ** 0.8)

        # Convert logits to probabilities and keep only the top-k candidates
        probs = F.softmax(logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)

        # Sample a position within the top-k set, then map it back to a vocabulary id
        next_token = torch.multinomial(top_k_probs, num_samples=1)
        xcol = torch.gather(top_k_indices, -1, next_token)

        # Stop if the EOS token is generated
        if xcol.item() == tokenizer.eos_token_id:
            break

        # Append the sampled token and continue
        input_ids = torch.cat([input_ids, xcol], dim=1)

    # Decode the full sequence (prompt + generated tokens)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

def main():
    # torch.set_float32_matmul_precision('high')
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=''' Follow the given instructions carefully. My mom is about to retire from her 10 long years of service to a company. write me a message saying how grateful we are for her service to our company. ''')
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)  # Number of candidate tokens kept at each sampling step
    args = parser.parse_args()

    model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
    # model = torch.compile(model)
    model = model.to(ModelArgs.device)

    # Load the DPO checkpoint and strip the torch.compile prefix from its state-dict keys
    dict_model = torch.load('DPO_model_1650.pt', map_location=ModelArgs.device)
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])
    model.eval()
    print("Model ready")

    with torch.no_grad():
        generated_text = topk_sampling(model, args.prompt, max_length=args.max_length, top_k=args.top_k, temperature=args.temperature, device=ModelArgs.device)
    # generated_text = remove_hashtag_lines(generated_text)
    print("Generated: ", generated_text)


if __name__ == '__main__':
    main()
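# Example invocation (the script filename below is illustrative; the checkpoint
# 'DPO_model_1650.pt' must sit next to the script as loaded in main()):
#   python generate.py --prompt "Write a short thank-you note." --max_length 128 --top_k 50 --temperature 0.8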