debug + gpu optimisation
Browse files
app.py
CHANGED
@@ -11,6 +11,8 @@ from controllable_blender import ControllableBlender
|
|
11 |
from huggingface_hub import snapshot_download
|
12 |
from huggingface_hub import login
|
13 |
|
|
|
|
|
14 |
token = os.environ.get("Token1")
|
15 |
|
16 |
login(token=token)
|
@@ -38,6 +40,7 @@ def init_world(cefr, inference_type):
|
|
38 |
opt = agent_opt.copy()
|
39 |
opt["rerank_cefr"] = cefr
|
40 |
opt["inference"] = inference_type
|
|
|
41 |
|
42 |
# Settings for rerank methods (not used if "inference" == "vocab")
|
43 |
opt["rerank_tokenizer"] = "distilroberta-base" # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
|
@@ -66,8 +69,12 @@ def chat(user_input, cefr, inference_type, history):
|
|
66 |
conversation_state["world"] = world
|
67 |
conversation_state["human_agent"] = human_agent
|
68 |
|
|
|
|
|
|
|
|
|
|
|
69 |
conversation_state["human_agent"].msg = user_input
|
70 |
-
|
71 |
conversation_state["world"].parley()
|
72 |
|
73 |
bot_reply = conversation_state["world"].acts[1].get("text", "")
|
|
|
11 |
from huggingface_hub import snapshot_download
|
12 |
from huggingface_hub import login
|
13 |
|
14 |
+
torch.set_default_dtype(torch.float16)
|
15 |
+
|
16 |
token = os.environ.get("Token1")
|
17 |
|
18 |
login(token=token)
|
|
|
40 |
opt = agent_opt.copy()
|
41 |
opt["rerank_cefr"] = cefr
|
42 |
opt["inference"] = inference_type
|
43 |
+
opt["gpu"]
|
44 |
|
45 |
# Settings for rerank methods (not used if "inference" == "vocab")
|
46 |
opt["rerank_tokenizer"] = "distilroberta-base" # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
|
|
|
69 |
conversation_state["world"] = world
|
70 |
conversation_state["human_agent"] = human_agent
|
71 |
|
72 |
+
print("🔥 Warming up...")
|
73 |
+
conversation_state["human_agent"].msg = "Hello"
|
74 |
+
conversation_state["world"].parley()
|
75 |
+
print("✅ Warmup complete.")
|
76 |
+
|
77 |
conversation_state["human_agent"].msg = user_input
|
|
|
78 |
conversation_state["world"].parley()
|
79 |
|
80 |
bot_reply = conversation_state["world"].acts[1].get("text", "")
|