import os

from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel

app = FastAPI()
# Read the Hugging Face token from the environment variable
hf_token = os.environ.get("HF_TOKEN")

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
# Load the transcript dataset used to build `texts` below (this call must be
# active, otherwise the `if dataset:` guard raises a NameError)
dataset = load_dataset(
    "Lhumpal/youtube-hunting-beast-transcripts",
    data_files={"concise": "concise/*", "raw": "raw/*"},
)

# Collect the concise transcripts with newlines flattened to spaces.
# NOTE: `texts` is prepared here but not yet wired into the /chat endpoint.
texts = []
if dataset:
    for file in dataset["concise"]:
        # Remove newline characters from the 'text' field
        cleaned_text = file["text"].replace("\n", " ")
        texts.append(cleaned_text)

class ChatRequest(BaseModel):
    message: str
    history: list[tuple[str, str]] = []
    system_message: str = "You are a friendly Chatbot."
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95


class ChatResponse(BaseModel):
    response: str

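# Illustrative request/response pair for the /chat endpoint below; the field
# values here are assumptions for demonstration, not required by the API:
#   request:  {"message": "Hi", "history": [["Hello", "Hey there!"]],
#              "max_tokens": 256}
#   response: {"response": "..."}
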
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        # Assemble the conversation in the chat-completion message format
        messages = [{"role": "system", "content": request.system_message}]
        for user_msg, assistant_msg in request.history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": request.message})

        # Stream tokens from the model and accumulate them into one string
        response = ""
        for message in client.chat_completion(
            messages,
            max_tokens=request.max_tokens,
            stream=True,
            temperature=request.temperature,
            top_p=request.top_p,
        ):
            token = message.choices[0].delta.content
            if token:  # delta.content can be None on some stream chunks
                response += token

        # The return value must match the ChatResponse model declared above
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
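
# Minimal sketch for running this API locally, assuming `uvicorn` is installed
# and HF_TOKEN is exported; port 7860 follows the Hugging Face Spaces default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call once the server is up (illustrative payload):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello", "history": [], "max_tokens": 128}'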