File size: 4,252 Bytes
3d18a82
 
a695aae
3d18a82
9da7e1d
3d18a82
 
 
 
 
 
 
148a80f
 
3643648
 
f52fcf2
 
bbe9094
3d18a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3643648
3d18a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3643648
 
2856fde
 
3643648
 
 
3d18a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cffbdb
3d18a82
8cffbdb
3d18a82
 
 
 
 
 
 
 
631882d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import os
os.environ["OMP_NUM_THREADS"] = "1"
import json
import torch
from parlai.core.opt import Opt
from parlai.zoo.blender.blender_3B import download
from parlai.core.agents import Agent
from parlai.core.params import ParlaiParser
from parlai.core.worlds import DialogPartnerWorld
from controllable_blender import ControllableBlender
from huggingface_hub import snapshot_download
from huggingface_hub import login

# Create new floating-point tensors in half precision by default.
torch.set_default_dtype(torch.float16)

# Hugging Face access token, read from the "Token1" environment variable.
token = os.environ.get("Token1")

# Only authenticate when a token is configured: login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
if token:
    login(token=token)

# Download the ControllableBlender model files into ParlAI's blender_3B model directory.
snapshot_download(repo_id="shivansarora/ControllableBlender", local_dir="ParlAI/data/models/blender/blender_3B")

# Load options; the with-block guarantees the file handle is closed.
with open("blender_3B.opt", "r", encoding="utf-8") as opt_file:
    agent_opt = json.load(opt_file)
# Run the ParlAI zoo download helper for blender_3B against the configured datapath.
download(agent_opt["datapath"])

# Module-global conversation state used by chat()/reset_chat().
# NOTE(review): being module-level, it is shared by every Gradio session.
conversation_state = {"world": None, "human_agent": None}

class GradioHumanAgent(Agent):
    """Stand-in for the human speaker in a ParlAI DialogPartnerWorld.

    There is no interactive input: the chat handler writes the user's text into
    ``msg`` before each ``world.parley()``, and ``act`` replays it as the human turn.
    """

    def __init__(self, opt):
        super().__init__(opt)
        self.msg = None  # text for the next human turn; assigned by chat()

    def observe(self, msg):
        """Receive the bot's turn; nothing is stored, the message is handed straight back."""
        return msg

    def act(self):
        """Emit the pending user text as a turn with ``episode_done`` set to False."""
        return dict(text=self.msg, episode_done=False)


def init_world(cefr, inference_type):
    """Create the human/bot agent pair and the ParlAI world that connects them.

    Args:
        cefr: Target CEFR level, stored in ``opt["rerank_cefr"]``.
        inference_type: Controllable-decoding method ("rerank" or "vocab"),
            stored in ``opt["inference"]``.

    Returns:
        ``(human_agent, world)`` where ``world`` is a DialogPartnerWorld over
        ``[human_agent, model_agent]``.
    """
    # Shallow copy so the top-level overrides below do not modify the shared agent_opt.
    opt = agent_opt.copy()
    opt["rerank_cefr"] = cefr
    opt["inference"] = inference_type

    # Settings for rerank methods (not used if "inference" == "vocab")
    opt["rerank_tokenizer"] = "distilroberta-base"        # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
    opt["rerank_model"] = "complexity_model"              # Model fine-tuned on complexity data
    opt["rerank_model_device"] = "cuda"                   # Device for complexity model (NOTE(review): hard-coded; requires a CUDA device)
    opt["penalty_stddev"] = 2                             # Controls how harshly sub-tokens are penalised (lower = harsher). Use -1 to remove penalties
    opt["filter_path"] = "data/filter.txt"                # Path to list of English words to ensure OOV words are not generated. Capitalised words are ignored. Use empty string to remove filter

    # Settings for vocab methods (not used if "inference" == "rerank")
    opt["wordlist_path"] = "data/sample_wordlist.txt"     # Path to list of vocab the chatbot is restricted to

    # Same top-k sampling configs for all settings described in the paper
    opt["beam_size"] = 20
    opt["topk"] = 40

    human_agent = GradioHumanAgent(opt)
    model_agent = ControllableBlender(opt)
    world = DialogPartnerWorld(opt, [human_agent, model_agent])

    return human_agent, world

def _warm_up(world, human_agent):
    """Run one throwaway "Hello" exchange with cheaper generation settings, then reset the world.

    ``beam_size``/``topk`` on the model agent's opt are lowered only for this exchange and
    restored afterwards (even if it fails), and ``world.reset()`` clears the agents so the
    hidden exchange does not remain in the model's dialogue history.
    """
    model_opt = world.agents[1].opt
    # Both keys are set by init_world, so they are present to save and restore.
    saved = {key: model_opt[key] for key in ("beam_size", "topk")}
    print("🔥 Warming up...")
    human_agent.msg = "Hello"
    model_opt["beam_size"] = 1
    model_opt["topk"] = 10
    try:
        world.parley()
    finally:
        model_opt.update(saved)
    world.reset()
    print("✅ Warmup complete.")


def chat(user_input, cefr, inference_type, history):
    """Send one user message to the bot and append the exchange to ``history``.

    The ParlAI world is built (and warmed up) lazily on the first call, and rebuilt
    whenever ``cefr``/``inference_type`` differ from the settings the current world was
    built with. On a rebuild the model's dialogue history restarts; the displayed
    ``history`` is not cleared.

    NOTE(review): conversation_state is module-global, so all sessions share one dialogue.

    Args:
        user_input: The user's message text.
        cefr: Target CEFR level, passed to init_world.
        inference_type: "rerank" or "vocab", passed to init_world.
        history: List of ``[user_message, bot_reply]`` pairs; appended to in place.

    Returns:
        ``(history, history)``: the same updated list, for the Chatbot and State outputs.
    """
    settings = (cefr, inference_type)
    if conversation_state["world"] is None or conversation_state.get("settings") != settings:
        # Drop references to any previous world first so its model can be freed before a new one loads.
        conversation_state["world"] = None
        conversation_state["human_agent"] = None
        human_agent, world = init_world(cefr, inference_type)
        _warm_up(world, human_agent)
        conversation_state["world"] = world
        conversation_state["human_agent"] = human_agent
        conversation_state["settings"] = settings

    human_agent = conversation_state["human_agent"]
    world = conversation_state["world"]
    human_agent.msg = user_input
    world.parley()

    # acts[1] is the model agent's reply from this parley.
    bot_reply = world.acts[1].get("text", "")
    history.append([user_input, bot_reply.strip()])
    return history, history

def reset_chat():
    """Forget the current ParlAI world and human agent.

    The next call to chat() will build a new world from scratch.

    Returns:
        A new empty list, for use as the cleared chat history.
    """
    for key in ("world", "human_agent"):
        conversation_state[key] = None
    return []

with gr.Blocks() as demo:
    # Generation controls: target CEFR level and controllable-decoding method.
    cefr = gr.Dropdown(["A1", "A2", "B1", "B2", "C1", "C2"], label="CEFR", value="B2")
    inference_type = gr.Dropdown(["rerank", "vocab"], label="Inference", value="rerank")
    user_input = gr.Textbox(label="your message")
    chatbot = gr.Chatbot(label="Controllable Complexity Chatbot")
    send_btn = gr.Button("Send")

    # Chat history as [user_message, bot_reply] pairs, passed back into chat() each turn.
    state = gr.State([])

    def user_chat(message, cefr_level, infer_type, history):
        """Gradio handler: forward one message to chat() and return the updated history.

        The history is returned twice: once for the Chatbot display and once for the State.
        Empty or whitespace-only messages are ignored.
        """
        if not message or not message.strip():
            return history, history
        # Log the submitted text itself (not the Textbox component object).
        print("Received:", message)
        new_history, _ = chat(message, cefr_level, infer_type, history)
        print("Replied:", new_history[-1][1])
        return new_history, new_history

    chat_inputs = [user_input, cefr, inference_type, state]
    chat_outputs = [chatbot, state]
    send_btn.click(fn=user_chat, inputs=chat_inputs, outputs=chat_outputs)
    # Pressing Enter in the textbox sends the message as well.
    user_input.submit(fn=user_chat, inputs=chat_inputs, outputs=chat_outputs)

demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)