File size: 2,053 Bytes
ffd9ec7
 
 
ec9d0bc
073264f
fa56304
522e9d7
ffd9ec7
 
4a1495c
 
a18df92
9c4feff
fa56304
a6caac4
a18df92
522e9d7
a18df92
 
 
 
 
 
522e9d7
7f349bb
 
522e9d7
7f349bb
522e9d7
 
 
ffd9ec7
7f349bb
 
fa56304
7f349bb
 
 
 
522e9d7
 
 
 
 
7f349bb
ffd9ec7
7f349bb
 
 
522e9d7
 
 
 
7f349bb
 
 
ffd9ec7
a18df92
fa56304
ffd9ec7
522e9d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os
from datasets import load_dataset
from huggingface_hub import login

app = FastAPI()

# Get the Hugging Face token from the environment variable (None if unset).
hf_token = os.environ.get("HF_TOKEN")
# login(token=hf_token)

# Inference client for the hosted zephyr-7b-beta chat model.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)

# Dataset loading is currently disabled.  Keep `dataset` defined anyway so
# the guard below does not raise NameError at import time (the original code
# referenced `dataset` while its only assignment was commented out).
dataset = None
# dataset = load_dataset("Lhumpal/youtube-hunting-beast-transcripts", data_files={"concise": "concise/*", "raw": "raw/*"})

# Always define `texts` so downstream code can rely on it existing,
# even when the dataset is not loaded.
texts = []
if dataset:
    for file in dataset["concise"]:
        # Remove newline characters from the 'text' field
        cleaned_text = file['text'].replace('\n', ' ')
        texts.append(cleaned_text)

class ChatRequest(BaseModel):
    """Request body for POST /chat.

    Carries the new user message, prior conversation turns, and the
    sampling parameters forwarded to the chat-completion call.
    """

    message: str                            # the new user message to answer
    history: list[tuple[str, str]] = []     # prior (user, assistant) turn pairs
    system_message: str = "You are a friendly Chatbot."  # system prompt prepended to the conversation
    max_tokens: int = 512                   # generation cap passed to chat_completion
    temperature: float = 0.7                # sampling temperature
    top_p: float = 0.95                     # nucleus-sampling cutoff

class ChatResponse(BaseModel):
    """Response body for POST /chat: the assistant's generated reply."""

    response: str  # full concatenated text of the streamed completion
    
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a chat reply with the zephyr-7b-beta model.

    Rebuilds the message list (system prompt, prior turns, new message),
    streams a completion from the inference client, and returns the
    concatenated text.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        messages = [{"role": "system", "content": request.system_message}]
        for user_msg, assistant_msg in request.history:
            # Skip empty halves of a turn so we don't send blank messages.
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": request.message})

        response = ""
        for message in client.chat_completion(
            messages,
            max_tokens=request.max_tokens,
            stream=True,
            temperature=request.temperature,
            top_p=request.top_p,
        ):
            token = message.choices[0].delta.content
            # Final/keep-alive stream chunks can carry None content;
            # concatenating None would raise TypeError, so skip them.
            if token:
                response += token

        # Key must match ChatResponse's schema ("response"); the original
        # returned {"assistant_response": ..., "dataset_sample": ...}, which
        # failed response_model validation on every request.
        return {"response": response}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))