RSHVR committed
Commit dad23dc · verified · 1 Parent(s): cb66cb4

Fix token and temperature inputs

Files changed (1):
  1. app.py (+54, -6)
app.py CHANGED

@@ -31,7 +31,7 @@ def show_model_change_info(model_name):
     return model_name
 
 
-async def respond(message, history, model_name="command-a-03-2025"):
+async def respond(message, history, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
     """Generate streaming response using Cohere API"""
 
     # Convert Gradio history format to API format
@@ -75,7 +75,9 @@ async def respond(message, history, model_name="command-a-03-2025"):
         user_message=message,
         conversation_history=conversation_history,
         api_key=api_key,
-        model_name=model_name
+        model_name=model_name,
+        temperature=temperature,
+        max_tokens=max_tokens
     ):
         partial_message += chunk
         yield partial_message
@@ -84,14 +86,33 @@ async def respond(message, history, model_name="command-a-03-2025"):
 
 with gr.Blocks() as demo:
     gr.Markdown("## Modular Chatbot")
+
+    # State components to track current values
+    temperature_state = gr.State(value=0.7)
+    max_tokens_state = gr.State(value=None)
+    model_state = gr.State(value=COHERE_MODELS[0])
 
     with gr.Row():
         with gr.Column(scale=2):
+            # Define wrapper function after all components are created
+            async def chat_wrapper(message, history, model_val, temp_val, tokens_val):
+                # Use the state values directly
+                current_model = model_val if model_val else COHERE_MODELS[0]
+                current_temp = temp_val if temp_val is not None else 0.7
+                current_max_tokens = tokens_val
+
+                # Stream the response
+                async for chunk in respond(message, history, current_model, current_temp, current_max_tokens):
+                    yield chunk
+
+            # Create chat interface using the wrapper with additional inputs
             chat_interface = gr.ChatInterface(
-                fn=respond,
+                fn=chat_wrapper,
                 type="messages",
-                save_history=True
+                save_history=True,
+                additional_inputs=[model_state, temperature_state, max_tokens_state]
             )
+
         with gr.Accordion("Chat Settings", elem_id="chat_settings_group"):
             with gr.Row():
                 with gr.Column(scale=3):
@@ -126,6 +147,15 @@ with gr.Blocks() as demo:
                             outputs=[model]
                         )
 
+                        # Update state when model changes
+                        model.change(
+                            fn=lambda x: x,
+                            inputs=[model],
+                            outputs=[model_state]
+                        )
+
+
+
                 with gr.Column(scale=1):
                     temperature = gr.Slider(
                         label="Temperature",
@@ -140,12 +170,30 @@ with gr.Blocks() as demo:
                     )
                     max_tokens = gr.Textbox(
                         label="Max Tokens",
-                        info="Higher values allow longer responses.",
+                        info="Higher values allow longer responses. Leave empty for default.",
                        value="8192",
                         elem_id="max_tokens_input",
                         interactive=True,
-                        show_label=False,
+                        show_label=True,
+                    )
+
+                    # Update state when temperature changes
+                    temperature.change(
+                        fn=lambda x: x,
+                        inputs=[temperature],
+                        outputs=[temperature_state]
+                    )
+
+                    # Update state when max_tokens changes
+                    max_tokens.change(
+                        fn=lambda x: int(x) if x and str(x).strip() else None,
+                        inputs=[max_tokens],
+                        outputs=[max_tokens_state]
                     )
+
+
+
 
 if __name__ == "__main__":
     demo.launch()
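
For reference, the pattern this commit lands on reduces to the minimal sketch below: gr.State objects are passed through ChatInterface's additional_inputs so their current values reach the chat function, and .change events on the visible controls keep those states in sync. This is an illustrative stand-in, not the app itself: respond here just echoes the settings back instead of calling the Cohere client, the model dropdown is omitted, and a recent Gradio release is assumed.

import gradio as gr

async def respond(message, history, temperature=0.7, max_tokens=None):
    # Stand-in for the streaming Cohere call: echo the settings back.
    yield f"(temp={temperature}, max_tokens={max_tokens}) {message}"

with gr.Blocks() as demo:
    # Hidden state mirrors of the visible controls.
    temperature_state = gr.State(value=0.7)
    max_tokens_state = gr.State(value=None)

    gr.ChatInterface(
        fn=respond,
        type="messages",
        additional_inputs=[temperature_state, max_tokens_state],
    )

    temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
    max_tokens = gr.Textbox(label="Max Tokens", value="8192")

    # Keep the states in sync; the chat fn reads whatever the states
    # held at the moment the message was submitted.
    temperature.change(lambda x: x, inputs=[temperature], outputs=[temperature_state])
    max_tokens.change(
        lambda x: int(x) if x and str(x).strip() else None,
        inputs=[max_tokens],
        outputs=[max_tokens_state],
    )

if __name__ == "__main__":
    demo.launch()

Routing State objects rather than the visible components through additional_inputs means ChatInterface does not try to render the controls inside its own accordion: the slider and textbox stay in the custom "Chat Settings" layout while their latest values still reach the chat function.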