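# Gradio demo for a CEFR-controllable BlenderBot 3B chatbot built on ParlAI.
# Output complexity is steered either by reranking candidate beams with a
# complexity model ("rerank") or by restricting decoding to a word list ("vocab").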
import gradio as gr
import os
os.environ["OMP_NUM_THREADS"] = "1"  # cap OpenMP threads; must be set before torch is imported
import json
import torch
from parlai.core.opt import Opt
from parlai.zoo.blender.blender_3B import download
from parlai.core.agents import Agent
from parlai.core.params import ParlaiParser
from parlai.core.worlds import DialogPartnerWorld
from controllable_blender import ControllableBlender
from huggingface_hub import snapshot_download, login
torch.set_default_dtype(torch.float16)  # create new float tensors in half precision to reduce memory use
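# Authenticate with the Hugging Face Hub and download the fine-tuned model files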
token = os.environ.get("Token1")
login(token=token)
snapshot_download(repo_id="shivansarora/ControllableBlender", local_dir="ParlAI/data/models/blender/blender_3B")
# Load the exported ParlAI options for Blender 3B
with open("blender_3B.opt", "r") as f:
    agent_opt = json.load(f)
download(agent_opt["datapath"])  # fetch the base Blender 3B files into the data path
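# Module-level state shared across Gradio callbacks (one conversation per process)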
conversation_state = {"world": None, "human_agent": None}
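
# Minimal ParlAI agent that stands in for the human side of the dialogue;
# the Gradio callback stores each user message in self.msg before parley()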
class GradioHumanAgent(Agent):
    def __init__(self, opt):
        super().__init__(opt)
        self.msg = None

    def observe(self, msg):
        # The human side does not need to react to the bot's reply
        return msg

    def act(self):
        # Relay whatever the UI callback stored in self.msg
        return {"text": self.msg, "episode_done": False}
def init_world(cefr, inference_type):
    opt = agent_opt.copy()
    opt["rerank_cefr"] = cefr
    opt["inference"] = inference_type
    opt["gpu"] = 0  # assumed: pin to the first GPU (the original bare `opt["gpu"]` lookup was a no-op)

    # Settings for rerank methods (not used if "inference" == "vocab")
    opt["rerank_tokenizer"] = "distilroberta-base"  # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
    opt["rerank_model"] = "complexity_model"  # Model fine-tuned on complexity data
    opt["rerank_model_device"] = "cuda"  # Device for the complexity model
    opt["penalty_stddev"] = 2  # Controls how harshly sub-tokens are penalised (lower = harsher). Use -1 to remove penalties
    opt["filter_path"] = "data/filter.txt"  # Path to a list of English words, ensuring OOV words are not generated. Capitalised words are ignored. Use an empty string to remove the filter

    # Settings for vocab methods (not used if "inference" == "rerank")
    opt["wordlist_path"] = "data/sample_wordlist.txt"  # Path to the vocabulary list the chatbot is restricted to

    # Same top-k sampling configs for all settings described in the paper
    opt["beam_size"] = 20
    opt["topk"] = 40

    human_agent = GradioHumanAgent(opt)
    model_agent = ControllableBlender(opt)
    world = DialogPartnerWorld(opt, [human_agent, model_agent])
    return human_agent, world
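
# Gradio callback: lazily create the world on the first message, run one cheap
# warmup parley, then exchange a single turn per call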
def chat(user_input, cefr, inference_type, history):
    if conversation_state["world"] is None:
        human_agent, world = init_world(cefr, inference_type)
        conversation_state["world"] = world
        conversation_state["human_agent"] = human_agent
        # One cheap parley so model loading / CUDA init happens before the first real turn
        print("🔥 Warming up...")
        conversation_state["human_agent"].msg = "Hello"
        model_opt = conversation_state["world"].agents[1].opt
        beam_size, topk = model_opt["beam_size"], model_opt["topk"]
        model_opt["beam_size"] = 1
        model_opt["topk"] = 10
        conversation_state["world"].parley()
        # Restore the real decoding settings (the original left the warmup values in place)
        model_opt["beam_size"] = beam_size
        model_opt["topk"] = topk
        print("✅ Warmup complete.")
    conversation_state["human_agent"].msg = user_input
    conversation_state["world"].parley()
    bot_reply = conversation_state["world"].acts[1].get("text", "")
    history.append([user_input, bot_reply.strip()])
    return history, history
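
# Drop the current world so the next message starts a fresh conversation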
def reset_chat():
    conversation_state["world"] = None
    conversation_state["human_agent"] = None
    return []
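
# Gradio UI: CEFR-level and inference-mode selectors plus a simple chat box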
with gr.Blocks() as demo:
    cefr = gr.Dropdown(["A1", "A2", "B1", "B2", "C1", "C2"], label="CEFR", value="B2")
    inference_type = gr.Dropdown(["rerank", "vocab"], label="Inference", value="rerank")
    user_input = gr.Textbox(label="your message")
    chatbot = gr.Chatbot(label="Controllable Complexity Chatbot")
    send_btn = gr.Button("Send")
    state = gr.State([])
    def user_chat(message, cefr_level, infer_type, history):
        print("Received:", message)  # the original printed the Textbox component, not the message
        new_history, _ = chat(message, cefr_level, infer_type, history)
        return new_history, new_history
    send_btn.click(
        fn=user_chat,
        inputs=[user_input, cefr, inference_type, state],
        outputs=[chatbot, state],
    )
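
    # reset_chat() above is never wired into the UI in the original; a minimal,
    # assumed hookup that clears both the chatbot display and the history state:
    reset_btn = gr.Button("Reset")
    reset_btn.click(fn=lambda: (reset_chat(), []), outputs=[chatbot, state])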
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)