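"""Gradio demo for ControllableBlender: a BlenderBot-3B chatbot whose output
complexity is steered toward a target CEFR level, either by reranking candidate
responses with a fine-tuned complexity model ("rerank") or by restricting
decoding to a fixed word list ("vocab").
"""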
import gradio as gr
import os
import json
import torch
from parlai.core.opt import Opt
from parlai.zoo.blender.blender_3B import download
from parlai.core.agents import Agent
from parlai.core.params import ParlaiParser
from parlai.core.worlds import DialogPartnerWorld
from controllable_blender import ControllableBlender
from huggingface_hub import snapshot_download
from huggingface_hub import login

# Run the 3B-parameter model in half precision to reduce GPU memory use.
torch.set_default_dtype(torch.float16)

# Authenticate with the Hugging Face Hub using a token stored in the
# "Token1" environment variable (e.g. a Space secret), then fetch the
# fine-tuned ControllableBlender weights into ParlAI's model directory.
token = os.environ.get("Token1")

login(token=token)

snapshot_download(repo_id="shivansarora/ControllableBlender", local_dir="ParlAI/data/models/blender/blender_3B")

# Load the saved ParlAI options for BlenderBot-3B and make sure the base
# model files are downloaded.
with open("blender_3B.opt") as f:
    agent_opt = json.load(f)
download(agent_opt["datapath"])

# The world and human agent are created lazily on the first message, so the
# CEFR level and inference type chosen in the UI can be applied.
conversation_state = {"world": None, "human_agent": None}

class GradioHumanAgent(Agent):
    """ParlAI agent that relays the Gradio user's message into the world."""

    def __init__(self, opt):
        super().__init__(opt)
        self.msg = None

    def observe(self, msg):
        # The bot's reply is read from world.acts in chat(), so nothing is stored here.
        return msg

    def act(self):
        return {"text": self.msg, "episode_done": False}


def init_world(cefr, inference_type):
    opt = agent_opt.copy()
    opt["rerank_cefr"] = cefr
    opt["inference"] = inference_type
    opt["gpu"] = 0  # assumption: run on the first GPU

    # Settings for rerank methods (not used if "inference" == "vocab")
    opt["rerank_tokenizer"] = "distilroberta-base"        # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
    opt["rerank_model"] = "complexity_model"              # Model fine-tuned on complexity data
    opt["rerank_model_device"] = "cuda"                   # Device for complexity model
    opt["penalty_stddev"] = 2                             # Controls how harshly sub-tokens are penalised (lower = harsher). Use -1 to remove penalties
    opt["filter_path"] = "data/filter.txt"                # Path to list of English words to ensure OOV words are not generated. Capitalised words are ignored. Use empty string to remove filter

    # Settings for vocab methods (not used if "inference" == "rerank")
    opt["wordlist_path"] = "data/sample_wordlist.txt"     # Path to list of vocab the chatbot is restricted to    
   
    # Same top-k sampling configs for all settings described in the paper
    opt["beam_size"] = 20
    opt["topk"] = 40

    human_agent = GradioHumanAgent(opt)
    model_agent = ControllableBlender(opt)
    world = DialogPartnerWorld(opt, [human_agent, model_agent])

    return human_agent, world

def chat(user_input, cefr, inference_type, history):
    if conversation_state["world"] is None:

        human_agent, world = init_world(cefr, inference_type)
        conversation_state["world"] = world
        conversation_state["human_agent"] = human_agent

        print("🔥 Warming up...")
        conversation_state["human_agent"].msg = "Hello"
        conversation_state["world"].agents[1].opt['beam_size'] = 1
        conversation_state["world"].agents[1].opt['topk'] = 10
        conversation_state["world"].parley()
        print("✅ Warmup complete.")
    
    conversation_state["human_agent"].msg = user_input
    conversation_state["world"].parley()

    bot_reply = conversation_state["world"].acts[1].get("text", "")
    history.append([user_input, bot_reply.strip()])
    return history, history

def reset_chat():
    # Drop the cached world so the next message re-initialises the model,
    # and clear both the visible chat and the stored history.
    conversation_state["world"] = None
    conversation_state["human_agent"] = None
    return [], []

with gr.Blocks() as demo:
    cefr = gr.Dropdown(["A1", "A2", "B1", "B2", "C1", "C2"], label="CEFR", value="B2")
    inference_type = gr.Dropdown(["rerank", "vocab"], label="Inference", value="rerank")
    user_input = gr.Textbox(label="Your message")
    chatbot = gr.Chatbot(label="Controllable Complexity Chatbot")
    send_btn = gr.Button("Send")

    state = gr.State([])

    def user_chat(message, cefr_level, infer_type, history):
        print("Received:", message)
        new_history, _ = chat(message, cefr_level, infer_type, history)
        return new_history, new_history

    send_btn.click(
        fn=user_chat,
        inputs=[user_input, cefr, inference_type, state],
        outputs=[chatbot, state]
    )
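
    # reset_chat() is defined above but was never wired into the UI; this
    # hookup is an assumption about the intended behaviour, clearing both the
    # visible chat and the stored history.
    reset_btn = gr.Button("Reset")
    reset_btn.click(
        fn=reset_chat,
        outputs=[chatbot, state]
    )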

# Bind to all interfaces so the app is reachable from inside a container
# (e.g. a Hugging Face Space).
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)