shivansarora commited on
Commit
3643648
·
verified ·
1 Parent(s): 8cffbdb

debug + gpu optimisation

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -11,6 +11,8 @@ from controllable_blender import ControllableBlender
11
  from huggingface_hub import snapshot_download
12
  from huggingface_hub import login
13
 
 
 
14
  token = os.environ.get("Token1")
15
 
16
  login(token=token)
@@ -38,6 +40,7 @@ def init_world(cefr, inference_type):
38
  opt = agent_opt.copy()
39
  opt["rerank_cefr"] = cefr
40
  opt["inference"] = inference_type
 
41
 
42
  # Settings for rerank methods (not used if "inference" == "vocab")
43
  opt["rerank_tokenizer"] = "distilroberta-base" # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
@@ -66,8 +69,12 @@ def chat(user_input, cefr, inference_type, history):
66
  conversation_state["world"] = world
67
  conversation_state["human_agent"] = human_agent
68
 
 
 
 
 
 
69
  conversation_state["human_agent"].msg = user_input
70
-
71
  conversation_state["world"].parley()
72
 
73
  bot_reply = conversation_state["world"].acts[1].get("text", "")
 
11
  from huggingface_hub import snapshot_download
12
  from huggingface_hub import login
13
 
14
+ torch.set_default_dtype(torch.float16)
15
+
16
  token = os.environ.get("Token1")
17
 
18
  login(token=token)
 
40
  opt = agent_opt.copy()
41
  opt["rerank_cefr"] = cefr
42
  opt["inference"] = inference_type
43
+ opt["gpu"]
44
 
45
  # Settings for rerank methods (not used if "inference" == "vocab")
46
  opt["rerank_tokenizer"] = "distilroberta-base" # Tokenizer from Huggingface Transformers. Must be compatible with "rerank_model"
 
69
  conversation_state["world"] = world
70
  conversation_state["human_agent"] = human_agent
71
 
72
+ print("🔥 Warming up...")
73
+ conversation_state["human_agent"].msg = "Hello"
74
+ conversation_state["world"].parley()
75
+ print("✅ Warmup complete.")
76
+
77
  conversation_state["human_agent"].msg = user_input
 
78
  conversation_state["world"].parley()
79
 
80
  bot_reply = conversation_state["world"].acts[1].get("text", "")