WillHeld committed on
Commit
b0dd995
·
verified ·
1 Parent(s): c185875

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -24
app.py CHANGED
@@ -1,36 +1,192 @@
1
  import spaces
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
 
 
 
 
 
 
 
4
 
5
# Model checkpoint used by the previous revision of this app.
checkpoint = "WillHeld/olmo-raccoon"
device = "cuda"  # assumes a CUDA-capable GPU — no CPU fallback
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU for up to 120 s per call
def predict(message, history, temperature, top_p):
    """Generate a single, non-streaming chat reply.

    Appends the user message to ``history`` (mutated in place), renders the
    tokenizer's chat template, samples up to 1024 new tokens, and returns
    the text following the last ``<|assistant|>`` marker in the decoded
    output.
    """
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs,
        max_new_tokens=1024,
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True
    )
    # decode() is called without skip_special_tokens, so the "<|assistant|>"
    # marker survives and can be used to slice out only the new reply.
    decoded = tokenizer.decode(outputs[0])
    response = decoded.split("<|assistant|>")[-1]
    return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
with gr.Blocks() as demo:
    # Chat UI; the two sliders are forwarded to predict() as the
    # temperature and top_p arguments.
    chatbot = gr.ChatInterface(
        predict,
        additional_inputs=[
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
        ],
        type="messages"
    )

demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  import gradio as gr
4
+ from threading import Thread
5
+ from datetime import datetime, timedelta
6
+ from datasets import Dataset
7
+ from huggingface_hub import HfApi, login
8
+ import uuid
9
+ import os
10
+ import time
11
 
12
checkpoint = "WillHeld/soft-raccoon"
device = "cuda"  # assumes a CUDA-capable GPU — no CPU fallback
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Dataset configuration
DATASET_NAME = "WillHeld/soft-raccoon-conversations"  # Change to your HF username
PUSH_TO_HUB = True  # Set to False if you just want to save locally first

# Time-based storage settings
SAVE_INTERVAL_MINUTES = 5  # Save every 5 minutes
last_save_time = datetime.now()

# Initialize storage for conversations (in-memory; persisted by save_to_dataset)
conversations = []

# Login to the Hugging Face Hub only when a token is actually available.
# FIX: the previous unconditional login(token=os.environ.get("HF_TOKEN"))
# passed token=None when the env var was unset, which makes huggingface_hub
# fall back to an interactive prompt (or fail outright in a headless Space).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
31
+
32
def save_to_dataset():
    """Build a Dataset from the accumulated conversations and persist it.

    Pushes to the Hub when PUSH_TO_HUB is set (keeping a local copy if the
    push fails); otherwise saves to the local 'local_dataset' directory.
    Returns the Dataset, or None when there is nothing to save.
    """
    if not conversations:
        return None

    # Pivot the list of conversation records into column-oriented form.
    columns = ("conversation_id", "timestamp", "messages", "metadata")
    dataset_dict = {col: [conv[col] for conv in conversations] for col in columns}

    dataset = Dataset.from_dict(dataset_dict)

    if PUSH_TO_HUB:
        try:
            # push_to_hub creates the dataset repo if it does not exist yet.
            dataset.push_to_hub(DATASET_NAME)
            print(f"Successfully pushed {len(conversations)} conversations to {DATASET_NAME}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")
            # Keep a local copy so nothing is lost when the push fails.
            dataset.save_to_disk("local_dataset")
    else:
        dataset.save_to_disk("local_dataset")
        print(f"Saved {len(conversations)} conversations locally to 'local_dataset'")

    return dataset
69
+
70
@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU for up to 120 s per call
def predict(message, history, temperature, top_p, conversation_id=None):
    """Stream a chat completion and log the exchange for dataset export.

    Runs ``model.generate`` on a background thread feeding a
    ``TextIteratorStreamer`` and yields the accumulated response text after
    every chunk, so the Gradio ChatInterface renders the reply
    incrementally. After generation finishes, the full exchange is recorded
    in the module-level ``conversations`` list and a time-based auto-save
    may fire.

    Args:
        message: New user message text.
        history: Chat history as a list of {"role", "content"} dicts
            (mutated in place).
        temperature: Sampling temperature, coerced to float for generate().
        top_p: Nucleus-sampling threshold, coerced to float.
        conversation_id: Identifier used to update an existing record.
            NOTE(review): this comes from a gr.State that is never written
            back by any handler, so it is presumably always None and each
            call creates a fresh record — verify against the UI wiring.
    """
    # Create or retrieve conversation ID for tracking
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())

    # Update history with user message
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Streamer bridges the generation thread and this generator; it skips
    # the prompt and special tokens so only new reply text is yielded.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # Run generation in a separate thread; iterating the streamer below
    # blocks until chunks arrive and ends when generation completes, so no
    # explicit join() is performed.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the growing response as tokens are generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # After generation completes, update history with the assistant response
    history.append({"role": "assistant", "content": partial_text})

    # Store conversation data: update the record if we have seen this
    # conversation_id before, otherwise append a new one.
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)

    if existing_conv:
        # Update existing conversation
        existing_conv["messages"] = history
        existing_conv["metadata"]["last_updated"] = datetime.now().isoformat()
    else:
        # Create new conversation record
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": datetime.now().isoformat(),
            "messages": history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": datetime.now().isoformat()
            }
        })

    # Auto-save when enough time has elapsed since the last save.
    global last_save_time
    current_time = datetime.now()
    if current_time - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time

    # In a generator this return only sets StopIteration.value; Gradio
    # consumes the yields above, so the value is effectively unused.
    return partial_text
137
+
138
def save_dataset_button():
    """Handler for the manual save button; returns a status message."""
    # save_to_dataset() returns None when there is nothing to persist.
    if save_to_dataset() is None:
        return "No conversations to save."
    return f"Saved {len(conversations)} conversations to dataset."
144
 
145
with gr.Blocks() as demo:
    # State intended to carry a per-session conversation id into predict().
    # NOTE(review): no handler ever writes this state back, so predict()
    # presumably always receives None — confirm the intended wiring.
    conversation_id = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                predict,
                additional_inputs=[
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
                    conversation_id
                ],
                type="messages"
            )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations to dataset")
                save_output = gr.Textbox(label="Save Status")

                # Display current conversation count (callable value is
                # re-evaluated on page load).
                conversation_count = gr.Number(value=lambda: len(conversations),
                                               label="Total Conversations",
                                               interactive=False)

                # Display time until next auto-save
                next_save_time = gr.Textbox(label="Next Auto-Save",
                                            value=lambda: f"In {SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60} minutes")
                refresh_button = gr.Button("Refresh Stats")

    # Set up event handlers
    save_button.click(save_dataset_button, outputs=save_output)

    def refresh_stats():
        """Return the conversation count and minutes until the next auto-save."""
        mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
        return len(conversations), f"In {mins_until_save} minutes"

    refresh_button.click(refresh_stats, outputs=[conversation_count, next_save_time])

    # FIX: gr.Blocks has no `on_close` method — the original
    # `demo.on_close(save_to_dataset)` raised AttributeError at startup.
    # Register the final save as a process-exit hook instead.
    import atexit
    atexit.register(save_to_dataset)

    # FIX: gr.Timer is a component with a `tick` event; it has no `.start()`
    # method and takes no callable positional argument, so the original
    # `gr.Timer(60, lambda: None).start()` also raised AttributeError.
    # Wire the tick to the stats refresh so the sidebar updates every 60 s.
    stats_timer = gr.Timer(60)
    stats_timer.tick(refresh_stats, outputs=[conversation_count, next_save_time])

if __name__ == "__main__":
    demo.launch()