Safetensors
llama
text
File size: 2,495 Bytes
4dcaf3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM


# Module-level logger following the standard `getLogger(__name__)` convention;
# not referenced anywhere in this file yet.
logger = logging.getLogger(__name__)


class Chat:
    r"""
    Minimal chat wrapper around a causal LM that uses a chat template.

    Maintains a running message history (``self.messages``) across calls to
    :meth:`message`; use :meth:`instruct` for stateless single-turn queries
    and :meth:`reset` to clear the history.
    """

    def __init__(
        self,
        path="mathewhe/Llama-3.1-8B-Chat",
        device="cuda",
    ):
        r"""
        Load the tokenizer and model from *path* and place the model on *device*.

        :param path: Hugging Face model id or local checkpoint directory.
        :param device: Device string; passed as ``device_map`` to the model
            and used to move tokenized inputs before generation.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)

        self.messages = list()  # running history of {"role", "content"} dicts
        self.device = device
        # Default sampling settings forwarded verbatim to model.generate().
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self):
        r"""Reset the chat message history."""
        self.messages = list()

    def _inference(self, messages):
        r"""
        Run a single generation pass over *messages* and return the reply text.

        :param messages: List of ``{"role": ..., "content": ...}`` dicts.
        :return: Decoded assistant response with special tokens stripped.
        """
        # FIX: add_generation_prompt=True appends the assistant-turn header to
        # the prompt so the model produces a reply instead of continuing the
        # user's turn (required for inference-time chat-template use).
        chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already contains the special tokens, so the
        # tokenizer must not add them a second time.
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(chat, return_tensors="pt", add_special_tokens=False).items()
        }
        # Remember the prompt length so only newly generated tokens are decoded.
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        return response

    def message(self, message):
        r"""
        Add the message to the chat history and return a response.

        The user message and the generated reply are both appended to
        ``self.messages`` so subsequent calls see the full conversation.
        """
        self.messages.append({"role": "user", "content": message})
        # TODO: cache the model's past_key_values between turns instead of
        # re-encoding the whole history on every call.
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""
        For CLI-based chatting (with history).

        Loops on stdin until the user submits an empty line.
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "

        print(f"{asst_prompt}Hi! How can I help you?\n")
        message = input(user_prompt)
        # input() always returns a str, so an empty string is the only
        # sentinel needed to stop the loop.
        while message:
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)

    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).

        :param message: The user instruction.
        :return: The model's response; ``self.messages`` is left untouched.
        """
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response


# Script entry point: start an interactive CLI chat session.
if __name__ == "__main__":
    Chat().cli_chat()