File size: 3,585 Bytes
faf7233
 
 
 
 
96b9b32
 
 
 
 
faf7233
 
 
 
 
 
 
 
 
 
1f1b386
faf7233
1f1b386
9ab8f00
1f1b386
faf7233
1f1b386
96b9b32
 
 
faf7233
1f1b386
96b9b32
a58e802
faf7233
 
 
 
1b4738b
faf7233
 
 
 
 
9ab8f00
 
 
 
faf7233
1b4738b
faf7233
 
 
 
 
 
 
 
 
 
7ba9eb3
faf7233
58d7ca2
faf7233
 
 
 
58d7ca2
faf7233
 
 
 
 
 
96b9b32
 
 
13d6e67
7ba9eb3
13d6e67
 
7ba9eb3
13d6e67
 
7ba9eb3
 
faf7233
 
 
 
 
 
 
 
 
 
 
 
 
 
1f1b386
 
96b9b32
 
1f1b386
 
96b9b32
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging

from fastapi import FastAPI
from llama_index.llms.llama_cpp import LlamaCPP
from transformers import AutoTokenizer
from llama_index.core import set_global_tokenizer
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex


logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

logger.info("Запускаемся... 🥳🥳🥳")

app = FastAPI()

# GGUF weights for Qwen2.5-7B-Instruct (Q3_K_M quantisation); LlamaCPP is given
# this URL so it can download the model itself.
model_url = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf"

# Optional system prompt for completion_to_prompt; an empty string means the
# built-in default instruction is used instead.
SYSTEM_PROMPT = ''

# Tokenizer matching the GGUF model; its chat template renders the prompts.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Hand LlamaIndex the encoder of the tokenizer loaded above rather than
# loading the same pretrained tokenizer a second time.
set_global_tokenizer(tokenizer.encode)

embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")

# Files under ./data/ become the documents indexed for retrieval.
documents = SimpleDirectoryReader("./data/").load_data()

def messages_to_prompt(messages):
    """Render a sequence of LlamaIndex chat messages as one prompt string.

    Each message becomes a ``{"role": ..., "content": ...}`` dict (the role
    taken from ``message.role.value``). The list is rendered as text, not
    token ids, through the module-level ``tokenizer``'s chat template, with
    the generation prompt appended.
    """
    chat = [
        {"role": message.role.value, "content": message.content}
        for message in messages
    ]
    return tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )


def completion_to_prompt(completion):
    """Wrap a plain completion string in a system + user chat and render it.

    The system turn uses ``SYSTEM_PROMPT`` when it is non-empty (read at call
    time). Otherwise it uses a built-in Russian instruction: "the answer must
    be accurate, concise and humorous". The conversation is rendered as text
    through the module-level ``tokenizer``'s chat template, with the
    generation prompt appended.
    """
    if SYSTEM_PROMPT:
        system_text = SYSTEM_PROMPT
    else:
        system_text = "Ответ должен быть точным, кратким и с юмором."

    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": completion},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )


llm = LlamaCPP(
    # GGUF model fetched from model_url; no pre-downloaded local path is used.
    model_url=model_url,
    model_path=None,
    temperature=0.1,
    max_new_tokens=256,
    # Context window (in tokens) given to llama.cpp.
    # NOTE(review): 2046 looks like a typo for 2048 — confirm before changing.
    context_window=2046,
    # Extra kwargs for each generation call (none).
    generate_kwargs={},
    # Constructor kwargs for the llama.cpp model:
    #   n_gpu_layers=-1 -> offload all layers to the GPU when one is available
    #   n_threads=2     -> number of CPU threads used for generation
    # Only llama.cpp options belong here; HF-transformers generation settings
    # (e.g. num_return_sequences, no_repeat_ngram_size) have no effect.
    model_kwargs={"n_gpu_layers": -1, "n_threads": 2},
    # Render prompts with the Qwen chat template via the module's tokenizer.
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
    verbose=True,
)

# Embed the loaded documents with the HuggingFace embedding model and expose
# the resulting vector index as a query engine answering with the local `llm`.
index = VectorStoreIndex.from_documents(
    documents,
    embed_model=embed_model,
)
query_engine = index.as_query_engine(llm=llm)

def generate_response(completion_response):
    """Return the stripped text of a model response, or a fallback message.

    Accepts LLM completion results that expose ``.text`` as well as
    LlamaIndex query-engine responses that expose ``.response`` instead.

    Args:
        completion_response: Object carrying the generated text.

    Returns:
        The text with surrounding whitespace removed. Returns
        ``"Пустой ответ"`` ("empty answer") when the text is empty or ``None``.
        Returns ``"Ошибка генерации"`` ("generation error") when the text
        cannot be read or stripped; in that case the error is logged with its
        traceback.
    """
    try:
        # Completions carry ``.text``; query-engine responses carry
        # ``.response``. An object with neither raises AttributeError here.
        if hasattr(completion_response, "text"):
            response_text = completion_response.text
        else:
            response_text = completion_response.response
        return response_text.strip() if response_text else "Пустой ответ"
    except Exception as e:
        # Kept broad so the caller always receives a string back.
        logger.exception("Ошибка обработки ответа: %s", e)
        return "Ошибка генерации"


@app.get("/")
def greet_json():
    """Root endpoint: reply with a fixed JSON greeting."""
    greeting = dict(Hello="World!")
    return greeting

@app.put("/system-prompt")
async def set_system_prompt(text: str):
    """Accept a new system prompt; the update itself is currently disabled.

    The handler only logs the call and returns ``None``. ``SYSTEM_PROMPT`` is
    not modified, so ``text`` is ignored.
    """
    # NOTE(review): the log label says "post" although this route is PUT.
    logger.info("post/system-prompt")
    # NOTE(review): assigning the prompt is intentionally switched off. Re-enabling
    # it (global SYSTEM_PROMPT; SYSTEM_PROMPT = text) would let any caller
    # change the system prompt — confirm that is wanted first.
    return None

@app.post("/predict")
async def predict(text: str):
    """Answer ``text`` with the RAG query engine built over ``./data/``.

    The request and the engine's answer are both written to the log. The
    query engine's response object is returned unchanged under the
    ``"response"`` key, leaving serialization to FastAPI.
    """
    logger.info("post/predict")
    logger.info("ЗАПРОС:")
    logger.info(text)

    # NOTE(review): query() is called synchronously (no await) inside an async
    # handler, so the event loop — and every other request, including "/" —
    # is blocked until generation finishes. Moving it to a worker thread would
    # also require serializing access to the single LlamaCPP model; confirm
    # before changing.
    answer = query_engine.query(text)

    logger.info("ОТВЕТ:")
    logger.info(answer)
    return {"response": answer}