|
import gradio as gr |
|
import os |
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
import json |
|
import torch |
|
from parlai.core.opt import Opt |
|
from parlai.zoo.blender.blender_3B import download |
|
from parlai.core.agents import Agent |
|
from parlai.core.params import ParlaiParser |
|
from parlai.core.worlds import DialogPartnerWorld |
|
from controllable_blender import ControllableBlender |
|
from huggingface_hub import snapshot_download |
|
from huggingface_hub import login |
|
|
|
# Default to fp16 to halve the 3B model's memory footprint.
torch.set_default_dtype(torch.float16)

# Hugging Face access token supplied via environment (e.g. a Space secret).
# NOTE(review): os.environ.get returns None if unset — login would then fail; confirm deployment always sets "Token1".
token = os.environ.get("Token1")

login(token=token)

# Fetch the fine-tuned checkpoint into the path ParlAI expects for blender_3B.
snapshot_download(repo_id="shivansarora/ControllableBlender", local_dir="ParlAI/data/models/blender/blender_3B")

# Agent options serialized by ParlAI; used to construct the model agent below.
agent_opt = json.load(open("blender_3B.opt", 'r'))
# Ensure the base blender_3B assets are present under the configured datapath.
download(agent_opt["datapath"])
# Module-level conversation singleton shared by all Gradio callbacks;
# populated lazily on the first chat() call.
conversation_state = {"world": None, "human_agent": None}
|
|
|
class GradioHumanAgent(Agent):
    """ParlAI agent that relays text typed by a Gradio user.

    The UI callback assigns the user's message to ``self.msg`` before the
    world's ``parley()`` runs; ``act`` then emits that message to the
    partner (model) agent.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # Pending message from the UI; None until the first turn.
        self.msg = None

    def observe(self, msg):
        # No state to keep from the partner's reply — pass it straight back.
        return msg

    def act(self):
        # Surface the pending UI message as a ParlAI act; the episode is
        # kept open so the conversation can continue.
        reply = {"text": self.msg, "episode_done": False}
        return reply
|
|
|
|
|
def init_world(cefr, inference_type):
    """Build the human/model agent pair and the ParlAI world linking them.

    Args:
        cefr: Target CEFR level ("A1".."C2") used by the complexity reranker.
        inference_type: Decoding strategy, "rerank" or "vocab".

    Returns:
        Tuple ``(human_agent, world)``: the GradioHumanAgent instance and
        the DialogPartnerWorld that runs the conversation.
    """
    opt = agent_opt.copy()
    opt["rerank_cefr"] = cefr
    opt["inference"] = inference_type
    # BUG FIX: removed the bare expression statement `opt["gpu"]` — it had
    # no effect (and would raise KeyError if the key were missing).

    # Complexity-reranker configuration.
    opt["rerank_tokenizer"] = "distilroberta-base"
    opt["rerank_model"] = "complexity_model"
    opt["rerank_model_device"] = "cuda"
    opt["penalty_stddev"] = 2
    opt["filter_path"] = "data/filter.txt"

    # Word list consulted when inference_type == "vocab".
    opt["wordlist_path"] = "data/sample_wordlist.txt"

    # Full-quality generation settings.
    opt["beam_size"] = 20
    opt["topk"] = 40

    human_agent = GradioHumanAgent(opt)
    model_agent = ControllableBlender(opt)
    world = DialogPartnerWorld(opt, [human_agent, model_agent])

    return human_agent, world
|
|
|
def chat(user_input, cefr, inference_type, history):
    """Run one conversational turn and append it to the chat history.

    Lazily builds the world on the first call (using the selected CEFR
    level and inference type) and runs a cheap warmup parley so the first
    real reply is not slowed by lazy initialization.

    Args:
        user_input: The user's message text.
        cefr: CEFR level for the reranker (only read on the first call).
        inference_type: "rerank" or "vocab" (only read on the first call).
        history: List of ``[user, bot]`` pairs; mutated in place.

    Returns:
        ``(history, history)`` — the updated history twice, matching the
        Gradio ``(chatbot, state)`` output pair.
    """
    if conversation_state["world"] is None:
        human_agent, world = init_world(cefr, inference_type)
        conversation_state["world"] = world
        conversation_state["human_agent"] = human_agent

        print("Warming up...")
        model_opt = world.agents[1].opt
        # Remember the configured settings, then shrink the search so the
        # warmup turn is cheap.
        full_beam = model_opt['beam_size']
        full_topk = model_opt['topk']
        model_opt['beam_size'] = 1
        model_opt['topk'] = 10
        conversation_state["human_agent"].msg = "Hello"
        world.parley()
        # BUG FIX: restore full-quality settings — previously the warmup
        # values (beam_size=1, topk=10) leaked into every subsequent turn.
        model_opt['beam_size'] = full_beam
        model_opt['topk'] = full_topk
        print("Warmup complete.")

    conversation_state["human_agent"].msg = user_input
    conversation_state["world"].parley()

    bot_reply = conversation_state["world"].acts[1].get("text", "")
    history.append([user_input, bot_reply.strip()])
    return history, history
|
|
|
def reset_chat():
    """Drop the cached world so the next message starts a fresh conversation."""
    for key in ("world", "human_agent"):
        conversation_state[key] = None
    # Empty history for the Chatbot widget.
    return []
|
|
|
with gr.Blocks() as demo:
    # Controls. CEFR/inference are only consulted when the world is first
    # built (the conversation caches them after the first message).
    cefr = gr.Dropdown(["A1", "A2", "B1", "B2", "C1", "C2"], label="CEFR", value="B2")
    inference_type = gr.Dropdown(["rerank", "vocab"], label="Inference", value="rerank")
    user_input = gr.Textbox(label="your message")
    chatbot = gr.Chatbot(label="Controllable Complexity Chatbot")
    send_btn = gr.Button("Send")

    # Chat history as [user, bot] pairs, mirrored into the Chatbot widget.
    state = gr.State([])

    def user_chat(message, cefr_level, infer_type, history):
        """Click handler: run one turn and return the updated history."""
        # BUG FIX: previously this printed the gr.Textbox component object
        # (`user_input`) twice instead of the submitted text.
        print("Received:", message)
        new_history, _ = chat(message, cefr_level, infer_type, history)
        return new_history, new_history

    send_btn.click(
        fn=user_chat,
        inputs=[user_input, cefr, inference_type, state],
        outputs=[chatbot, state]
    )

demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)
|
|