Update app.py
app.py
CHANGED
@@ -8,13 +8,9 @@ import math
 import ast
 import logging
 import numpy as np
-
-from …
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
 from scipy import stats
-from scipy.stats import entropy
-from scipy.signal import correlate
-import networkx as nx
-from matplotlib.widgets import Cursor
 
 # Set up logging
 logging.basicConfig(level=logging.DEBUG)
@@ -63,8 +59,8 @@ def ensure_float(value):
         return float(value)
     return None
 
-# Function to process and visualize log probs with …
-def visualize_logprobs(json_input, prob_filter=-1e9):
+# Function to process and visualize log probs with interactive Plotly plots
+def visualize_logprobs(json_input, prob_filter=-1e9, page_size=50, page=0):
     try:
         # Parse the input (handles both JSON and Python dictionaries)
         data = parse_input(json_input)
@@ -81,18 +77,11 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
         tokens = []
         logprobs = []
         top_alternatives = []  # List to store top 3 log probs (selected token + 2 alternatives)
-        token_types = []  # Simplified token type categorization
         for entry in content:
             logprob = ensure_float(entry.get("logprob", None))
             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
                 tokens.append(entry["token"])
                 logprobs.append(logprob)
-                # Categorize token type (simple heuristic)
-                token = entry["token"].lower().strip()
-                if token in ["the", "a", "an"]: token_types.append("article")
-                elif token in ["is", "are", "was", "were"]: token_types.append("verb")
-                elif token in ["top", "so", "need", "figure"]: token_types.append("noun")
-                else: token_types.append("other")
                 # Get top_logprobs, default to empty dict if None
                 top_probs = entry.get("top_logprobs", {})
                 # Ensure all values in top_logprobs are floats
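
The filtering loop above drops any entry whose logprob is missing, non-numeric, non-finite, or below the slider threshold. A minimal sketch of the same rule in isolation; the keep helper and the sample entries are illustrative, not part of the commit:

    import math

    def keep(entry, prob_filter=-1e9):
        # Mirrors ensure_float + math.isfinite + the threshold test above
        lp = entry.get("logprob")
        try:
            lp = float(lp)
        except (TypeError, ValueError):
            return False
        return math.isfinite(lp) and lp >= prob_filter

    entries = [
        {"token": "Hello", "logprob": -0.01},
        {"token": "mystery", "logprob": None},         # dropped: not a number
        {"token": "rare", "logprob": float("-inf")},   # dropped: not finite
    ]
    print([e["token"] for e in entries if keep(e)])    # -> ['Hello']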
@@ -112,505 +101,76 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
 
         # Check if there's valid data after filtering
         if not logprobs or not tokens:
-            return ("No finite log probabilities or tokens to visualize after filtering …
-
-        # …
-        …
-                if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
-                    token_annotations[i].set_text(tokens[i])
-                    token_annotations[i].set_visible(True)
-                    fig_main.canvas.draw_idle()
-                else:
-                    token_annotations[i].set_visible(False)
-                    fig_main.canvas.draw_idle()
-
-            fig_main.canvas.mpl_connect('button_press_event', on_click)
-
-            buf_main = io.BytesIO()
-            plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
-            buf_main.seek(0)
-            plt.close(fig_main)
-            return buf_main
-
-        # 2. K-Means Clustering of Log Probabilities
-        def create_cluster_plot():
-            if not logprobs:
-                raise ValueError("No data for clustering plot")
-            kmeans = KMeans(n_clusters=3, random_state=42)
-            cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
-            fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
-            scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
-            ax_cluster.set_title("K-Means Clustering of Log Probabilities")
-            ax_cluster.set_xlabel("Token Position")
-            ax_cluster.set_ylabel("Log Probability")
-            ax_cluster.grid(True)
-            plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
-            buf_cluster = io.BytesIO()
-            plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
-            buf_cluster.seek(0)
-            plt.close(fig_cluster)
-            return buf_cluster
-
-        # 3. Probability Drop Analysis
-        def create_drops_plot():
-            if not logprobs or len(logprobs) < 2:
-                raise ValueError("Insufficient data for probability drops")
-            drops = [logprobs[i+1] - logprobs[i] if i < len(logprobs)-1 else 0 for i in range(len(logprobs))]
-            fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
-            ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
-            ax_drops.set_title("Significant Probability Drops")
-            ax_drops.set_xlabel("Token Position")
-            ax_drops.set_ylabel("Log Probability Drop")
-            ax_drops.grid(True)
-            buf_drops = io.BytesIO()
-            plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
-            buf_drops.seek(0)
-            plt.close(fig_drops)
-            return buf_drops
-
-        # 4. N-Gram Analysis (Bigrams for simplicity)
-        def create_ngram_plot():
-            if not logprobs or len(logprobs) < 2:
-                raise ValueError("Insufficient data for N-gram analysis")
-            bigrams = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
-            bigram_probs = [logprobs[i] + logprobs[i+1] for i in range(len(tokens)-1)]
-            fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
-            ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
-            ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
-            ax_ngram.set_xlabel("Bigram Position")
-            ax_ngram.set_ylabel("Sum of Log Probabilities")
-            ax_ngram.set_xticks(range(len(bigrams)))
-            ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
-            ax_ngram.grid(True)
-            buf_ngram = io.BytesIO()
-            plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
-            buf_ngram.seek(0)
-            plt.close(fig_ngram)
-            return buf_ngram
+            return (gr.update(value="No finite log probabilities or tokens to visualize after filtering"), None, None, None, 1, 0)
+
+        # Paginate data for large inputs
+        total_pages = max(1, (len(logprobs) + page_size - 1) // page_size)
+        start_idx = page * page_size
+        end_idx = min((page + 1) * page_size, len(logprobs))
+        paginated_tokens = tokens[start_idx:end_idx]
+        paginated_logprobs = logprobs[start_idx:end_idx]
+        paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
+
+        # 1. Main Log Probability Plot (Interactive Plotly)
+        main_fig = go.Figure()
+        main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
+        main_fig.update_layout(
+            title="Log Probabilities of Generated Tokens",
+            xaxis_title="Token Position",
+            yaxis_title="Log Probability",
+            hovermode="closest",
+            clickmode='event+select'
+        )
+        main_fig.update_traces(
+            customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
+            hovertemplate='<b>%{customdata}</b><extra></extra>'
+        )
 
-        # …
-        …
-        for i in range(len( …
-        …
+        # 2. Probability Drop Analysis (Interactive Plotly)
+        if len(paginated_logprobs) < 2:
+            drops_fig = go.Figure()
+            drops_fig.add_trace(go.Bar(x=list(range(len(paginated_logprobs)-1)), y=[0], name='Drop', marker_color='red'))
+        else:
+            drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
+            drops_fig = go.Figure()
+            drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
+        drops_fig.update_layout(
+            title="Significant Probability Drops",
+            xaxis_title="Token Position",
+            yaxis_title="Log Probability Drop",
+            hovermode="closest",
+            clickmode='event+select'
+        )
+        drops_fig.update_traces(
+            customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}, Position: {i+start_idx}" for i, drop in enumerate(drops)],
+            hovertemplate='<b>%{customdata}</b><extra></extra>'
+        )
 
-        # …
-        …
+        # 3. Anomaly Detection (Interactive Plotly)
+        if not paginated_logprobs:
+            anomaly_fig = go.Figure()
+            anomaly_fig.add_trace(go.Scatter(x=[], y=[], mode='markers+lines', name='Log Prob', marker_color='blue'))
+        else:
+            z_scores = np.abs(stats.zscore(paginated_logprobs))
             outliers = z_scores > 2  # Threshold for outliers
+            anomaly_fig = go.Figure()
+            anomaly_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker_color='blue'))
+            anomaly_fig.add_trace(go.Scatter(x=np.where(outliers)[0], y=[paginated_logprobs[i] for i in np.where(outliers)[0]], mode='markers', name='Outliers', marker_color='red'))
+        anomaly_fig.update_layout(
+            title="Log Probabilities with Outliers",
+            xaxis_title="Token Position",
+            yaxis_title="Log Probability",
+            hovermode="closest",
+            clickmode='event+select'
+        )
+        anomaly_fig.update_traces(
+            customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}, Outlier: {out}" for i, (tok, prob, out) in enumerate(zip(paginated_tokens, paginated_logprobs, outliers))],
+            hovertemplate='<b>%{customdata}</b><extra></extra>'
+        )
-        …
-        # 7. Autocorrelation
-        def create_autocorr_plot():
-            if not logprobs:
-                raise ValueError("No data for autocorrelation")
-            autocorr = correlate(logprobs, logprobs, mode='full')
-            autocorr = autocorr[len(autocorr)//2:] / len(logprobs)  # Normalize
-            fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
-            ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
-            ax_autocorr.set_title("Autocorrelation of Log Probabilities")
-            ax_autocorr.set_xlabel("Lag")
-            ax_autocorr.set_ylabel("Autocorrelation")
-            ax_autocorr.grid(True)
-            buf_autocorr = io.BytesIO()
-            plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
-            buf_autocorr.seek(0)
-            plt.close(fig_autocorr)
-            return buf_autocorr
-
-        # 8. Smoothing (Moving Average)
-        def create_smoothing_plot():
-            if not logprobs:
-                raise ValueError("No data for smoothing")
-            window_size = 3
-            moving_avg = np.convolve(logprobs, np.ones(window_size)/window_size, mode='valid')
-            fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
-            ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
-            ax_smoothing.plot(range(window_size-1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
-            ax_smoothing.set_title("Log Probabilities with Moving Average")
-            ax_smoothing.set_xlabel("Token Position")
-            ax_smoothing.set_ylabel("Log Probability")
-            ax_smoothing.grid(True)
-            ax_smoothing.legend()
-            ax_smoothing.set_xticks([])  # Hide X-axis labels
-            buf_smoothing = io.BytesIO()
-            plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
-            buf_smoothing.seek(0)
-            plt.close(fig_smoothing)
-            return buf_smoothing
-
-        # 9. Uncertainty Propagation (Variance of Top Logprobs)
-        def create_uncertainty_plot():
-            if not logprobs or not top_alternatives:
-                raise ValueError("No data for uncertainty propagation")
-            variances = []
-            for probs in top_alternatives:
-                if len(probs) > 1:
-                    values = [p[1] for p in probs]
-                    variances.append(np.var(values))
-                else:
-                    variances.append(0)
-            fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
-            ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
-            ax_uncertainty.fill_between(range(len(logprobs)), [lp - v for lp, v in zip(logprobs, variances)],
-                                        [lp + v for lp, v in zip(logprobs, variances)], color='gray', alpha=0.3, label="Uncertainty")
-            ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
-            ax_uncertainty.set_xlabel("Token Position")
-            ax_uncertainty.set_ylabel("Log Probability")
-            ax_uncertainty.grid(True)
-            ax_uncertainty.legend()
-            ax_uncertainty.set_xticks([])  # Hide X-axis labels
-            buf_uncertainty = io.BytesIO()
-            plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
-            buf_uncertainty.seek(0)
-            plt.close(fig_uncertainty)
-            return buf_uncertainty
-
-        # 10. Correlation Heatmap
-        def create_corr_plot():
-            if not logprobs or len(logprobs) < 2:
-                raise ValueError("Insufficient data for correlation heatmap")
-            corr_matrix = np.corrcoef(logprobs, rowvar=False)
-            fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
-            im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
-            ax_corr.set_title("Correlation of Log Probabilities Across Positions")
-            ax_corr.set_xlabel("Token Position")
-            ax_corr.set_ylabel("Token Position")
-            plt.colorbar(im, ax=ax_corr, label="Correlation")
-            buf_corr = io.BytesIO()
-            plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
-            buf_corr.seek(0)
-            plt.close(fig_corr)
-            return buf_corr
-
-        # 11. Token Type Correlation
-        def create_type_plot():
-            if not logprobs or not token_types:
-                raise ValueError("No data for token type correlation")
-            type_probs = {t: [] for t in set(token_types)}
-            for t, p in zip(token_types, logprobs):
-                type_probs[t].append(p)
-            fig_type, ax_type = plt.subplots(figsize=(10, 5))
-            for t in type_probs:
-                ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
-            ax_type.set_title("Average Log Probability by Token Type")
-            ax_type.set_xlabel("Token Type")
-            ax_type.set_ylabel("Average Log Probability")
-            ax_type.grid(True)
-            ax_type.legend()
-            buf_type = io.BytesIO()
-            plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
-            buf_type.seek(0)
-            plt.close(fig_type)
-            return buf_type
-
-        # 12. Token Embedding Similarity vs. Probability (Simulated)
-        def create_embed_plot():
-            if not logprobs or not tokens:
-                raise ValueError("No data for embedding similarity")
-            simulated_embeddings = np.random.rand(len(tokens), 2)  # 2D embeddings
-            fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
-            ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
-            ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
-            ax_embed.set_xlabel("Embedding Dimension 1")
-            ax_embed.set_ylabel("Embedding Dimension 2")
-            plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
-            buf_embed = io.BytesIO()
-            plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
-            buf_embed.seek(0)
-            plt.close(fig_embed)
-            return buf_embed
-
-        # 13. Bayesian Inference (Simplified as Inferred Probabilities)
-        def create_bayesian_plot():
-            if not top_alternatives:
-                raise ValueError("No data for Bayesian inference")
-            entropies = [entropy([p[1] for p in probs], base=2) for probs in top_alternatives if len(probs) > 1]
-            fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
-            ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
-            ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
-            ax_bayesian.set_xlabel("Token Position")
-            ax_bayesian.set_ylabel("Entropy")
-            ax_bayesian.grid(True)
-            buf_bayesian = io.BytesIO()
-            plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
-            buf_bayesian.seek(0)
-            plt.close(fig_bayesian)
-            return buf_bayesian
-
-        # 14. Graph-Based Analysis
-        def create_graph_plot():
-            if not tokens or len(tokens) < 2:
-                raise ValueError("Insufficient data for graph analysis")
-            G = nx.DiGraph()
-            for i in range(len(tokens)-1):
-                G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
-            fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
-            pos = nx.spring_layout(G)
-            nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
-            ax_graph.set_title("Graph of Token Transitions")
-            buf_graph = io.BytesIO()
-            plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
-            buf_graph.seek(0)
-            plt.close(fig_graph)
-            return buf_graph
-
-        # 15. Dimensionality Reduction (t-SNE)
-        def create_tsne_plot():
-            if not logprobs or not top_alternatives:
-                raise ValueError("No data for t-SNE")
-            features = np.array([logprobs + [p[1] for p in alts[:2]] for logprobs, alts in zip([logprobs], top_alternatives)])
-            tsne = TSNE(n_components=2, random_state=42)
-            tsne_result = tsne.fit_transform(features.T)
-            fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
-            scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
-            ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
-            ax_tsne.set_xlabel("t-SNE Dimension 1")
-            ax_tsne.set_ylabel("t-SNE Dimension 2")
-            plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
-            buf_tsne = io.BytesIO()
-            plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
-            buf_tsne.seek(0)
-            plt.close(fig_tsne)
-            return buf_tsne
-
-        # 16. Interactive Heatmap
-        def create_heatmap_plot():
-            if not logprobs:
-                raise ValueError("No data for heatmap")
-            fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
-            im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
-            ax_heatmap.set_title("Interactive Heatmap of Log Probabilities")
-            ax_heatmap.set_xlabel("Token Position")
-            ax_heatmap.set_ylabel("Probability Level")
-            plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
-            buf_heatmap = io.BytesIO()
-            plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
-            buf_heatmap.seek(0)
-            plt.close(fig_heatmap)
-            return buf_heatmap
-
-        # 17. Probability Distribution Plots (Box Plots for Top Logprobs)
-        def create_dist_plot():
-            if not logprobs or not top_alternatives:
-                raise ValueError("No data for probability distribution")
-            all_top_probs = [p[1] for alts in top_alternatives for p in alts]
-            fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
-            ax_dist.boxplot([logprobs] + [p[1] for alts in top_alternatives for p in alts[:2]], labels=["Selected"] + ["Alt1", "Alt2"])
-            ax_dist.set_title("Probability Distribution of Top Tokens")
-            ax_dist.set_xlabel("Token Type")
-            ax_dist.set_ylabel("Log Probability")
-            ax_dist.grid(True)
-            buf_dist = io.BytesIO()
-            plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
-            buf_dist.seek(0)
-            plt.close(fig_dist)
-            return buf_dist
-
-        # Create all plots safely
-        img_main_html = "Placeholder for Log Probability Plot"
-        img_cluster_html = "Placeholder for K-Means Clustering"
-        img_drops_html = "Placeholder for Probability Drops"
-        img_ngram_html = "Placeholder for N-Gram Analysis"
-        img_markov_html = "Placeholder for Markov Chain"
-        img_anomaly_html = "Placeholder for Anomaly Detection"
-        img_autocorr_html = "Placeholder for Autocorrelation"
-        img_smoothing_html = "Placeholder for Smoothing (Moving Average)"
-        img_uncertainty_html = "Placeholder for Uncertainty Propagation"
-        img_corr_html = "Placeholder for Correlation Heatmap"
-        img_type_html = "Placeholder for Token Type Correlation"
-        img_embed_html = "Placeholder for Embedding Similarity vs. Probability"
-        img_bayesian_html = "Placeholder for Bayesian Inference (Entropy)"
-        img_graph_html = "Placeholder for Graph of Token Transitions"
-        img_tsne_html = "Placeholder for t-SNE of Log Probabilities"
-        img_heatmap_html = "Placeholder for Interactive Heatmap"
-        img_dist_html = "Placeholder for Probability Distribution"
-
-        try:
-            buf_main = create_main_plot()
-            img_main_bytes = buf_main.getvalue()
-            img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
-            img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create main plot: %s", str(e))
-
-        try:
-            buf_cluster = create_cluster_plot()
-            img_cluster_bytes = buf_cluster.getvalue()
-            img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
-            img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create cluster plot: %s", str(e))
-
-        try:
-            buf_drops = create_drops_plot()
-            img_drops_bytes = buf_drops.getvalue()
-            img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
-            img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create drops plot: %s", str(e))
-
-        try:
-            buf_ngram = create_ngram_plot()
-            img_ngram_bytes = buf_ngram.getvalue()
-            img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
-            img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create ngram plot: %s", str(e))
-
-        try:
-            buf_markov = create_markov_plot()
-            img_markov_bytes = buf_markov.getvalue()
-            img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
-            img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create markov plot: %s", str(e))
-
-        try:
-            buf_anomaly = create_anomaly_plot()
-            img_anomaly_bytes = buf_anomaly.getvalue()
-            img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
-            img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create anomaly plot: %s", str(e))
-
-        try:
-            buf_autocorr = create_autocorr_plot()
-            img_autocorr_bytes = buf_autocorr.getvalue()
-            img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
-            img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create autocorr plot: %s", str(e))
-
-        try:
-            buf_smoothing = create_smoothing_plot()
-            img_smoothing_bytes = buf_smoothing.getvalue()
-            img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
-            img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create smoothing plot: %s", str(e))
-
-        try:
-            buf_uncertainty = create_uncertainty_plot()
-            img_uncertainty_bytes = buf_uncertainty.getvalue()
-            img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
-            img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create uncertainty plot: %s", str(e))
-
-        try:
-            buf_corr = create_corr_plot()
-            img_corr_bytes = buf_corr.getvalue()
-            img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
-            img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create correlation plot: %s", str(e))
-
-        try:
-            buf_type = create_type_plot()
-            img_type_bytes = buf_type.getvalue()
-            img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
-            img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create type plot: %s", str(e))
-
-        try:
-            buf_embed = create_embed_plot()
-            img_embed_bytes = buf_embed.getvalue()
-            img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
-            img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create embed plot: %s", str(e))
-
-        try:
-            buf_bayesian = create_bayesian_plot()
-            img_bayesian_bytes = buf_bayesian.getvalue()
-            img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
-            img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create bayesian plot: %s", str(e))
-
-        try:
-            buf_graph = create_graph_plot()
-            img_graph_bytes = buf_graph.getvalue()
-            img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
-            img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create graph plot: %s", str(e))
-
-        try:
-            buf_tsne = create_tsne_plot()
-            img_tsne_bytes = buf_tsne.getvalue()
-            img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
-            img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create tsne plot: %s", str(e))
-
-        try:
-            buf_heatmap = create_heatmap_plot()
-            img_heatmap_bytes = buf_heatmap.getvalue()
-            img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
-            img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create heatmap plot: %s", str(e))
-
-        try:
-            buf_dist = create_dist_plot()
-            img_dist_bytes = buf_dist.getvalue()
-            img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
-            img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create distribution plot: %s", str(e))
 
-        # Create DataFrame for the table
+        # Create DataFrame for the table (paginated)
         table_data = []
-        for i, entry in enumerate(content):
+        for i, entry in enumerate(content[start_idx:end_idx]):
             logprob = ensure_float(entry.get("logprob", None))
             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
                 token = entry["token"]
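
The added block chains three small techniques: ceil-division pagination, per-point Plotly hover text via customdata, and the |z| > 2 outlier rule. A self-contained sketch of all three on made-up data; the names and values here are illustrative, not the app's:

    import numpy as np
    import plotly.graph_objects as go
    from scipy import stats

    tokens = [f"tok{i}" for i in range(130)]                 # fake tokens
    logprobs = [-0.05 * (i % 7) - 0.1 for i in range(130)]   # fake log probs
    page, page_size = 2, 50

    # Ceil division: 130 items at 50 per page -> 3 pages
    total_pages = max(1, (len(logprobs) + page_size - 1) // page_size)
    start = page * page_size
    end = min((page + 1) * page_size, len(logprobs))
    page_toks, page_probs = tokens[start:end], logprobs[start:end]  # items 100..129

    # Same outlier rule as the diff: |z-score| > 2
    outliers = np.abs(stats.zscore(page_probs)) > 2

    fig = go.Figure(go.Scatter(
        x=list(range(len(page_probs))), y=page_probs, mode="markers+lines",
        # One customdata entry per point; %{customdata} splices it into the tooltip
        customdata=[f"Token: {t}, Position: {start + i}, Outlier: {o}"
                    for i, (t, o) in enumerate(zip(page_toks, outliers))],
        hovertemplate="<b>%{customdata}</b><extra></extra>",
    ))
    fig.show()

The <extra></extra> suffix in the hovertemplate suppresses the secondary box Plotly would otherwise append with the trace name.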
@@ -645,75 +205,52 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
                 else None
             )
 
-        # Generate colored text
-        if …
-            min_logprob = min( …
-            max_logprob = max( …
+        # Generate colored text (paginated)
+        if paginated_logprobs:
+            min_logprob = min(paginated_logprobs)
+            max_logprob = max(paginated_logprobs)
             if max_logprob == min_logprob:
-                normalized_probs = [0.5] * len( …
+                normalized_probs = [0.5] * len(paginated_logprobs)
             else:
                 normalized_probs = [
-                    (lp - min_logprob) / (max_logprob - min_logprob) for lp in …
+                    (lp - min_logprob) / (max_logprob - min_logprob) for lp in paginated_logprobs
                 ]
 
             colored_text = ""
-            for i, (token, norm_prob) in enumerate(zip( …
+            for i, (token, norm_prob) in enumerate(zip(paginated_tokens, normalized_probs)):
                 r = int(255 * (1 - norm_prob))  # Red for low confidence
                 g = int(255 * norm_prob)  # Green for high confidence
                 b = 0
                 color = f"rgb({r}, {g}, {b})"
                 colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
-                if i < len( …
+                if i < len(paginated_tokens) - 1:
                     colored_text += " "
             colored_text_html = f"<p>{colored_text}</p>"
         else:
             colored_text_html = "No finite log probabilities to display."
 
-        # Top 3 Token Log Probabilities
+        # Top 3 Token Log Probabilities (paginated)
         alt_viz_html = ""
-        if …
-            alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
-            for i, (token, probs) in enumerate(zip( …
-                alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
+        if paginated_logprobs and paginated_alternatives:
+            alt_viz_html = "<h3>Top 3 Token Log Probabilities (Paginated)</h3><ul>"
+            for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)):
+                alt_viz_html += f"<li>Position {i+start_idx} (Token: {token}):<br>"
                 for tok, prob in probs:
                     alt_viz_html += f"{tok}: {prob:.4f}<br>"
                 alt_viz_html += "</li>"
            alt_viz_html += "</ul>"
 
-
-        def buffer_to_html(buf):
-            if isinstance(buf, str):  # If it's an error message
-                return buf
-            img_bytes = buf.getvalue()
-            img_base64 = base64.b64encode(img_bytes).decode("utf-8")
-            return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
-
-        return (
-            buffer_to_html(img_main_html), df, colored_text_html, alt_viz_html,
-            buffer_to_html(img_cluster_html), buffer_to_html(img_drops_html), buffer_to_html(img_ngram_html),
-            buffer_to_html(img_markov_html), buffer_to_html(img_anomaly_html), buffer_to_html(img_autocorr_html),
-            buffer_to_html(img_smoothing_html), buffer_to_html(img_uncertainty_html), buffer_to_html(img_corr_html),
-            buffer_to_html(img_type_html), buffer_to_html(img_embed_html), buffer_to_html(img_bayesian_html),
-            buffer_to_html(img_graph_html), buffer_to_html(img_tsne_html), buffer_to_html(img_heatmap_html),
-            buffer_to_html(img_dist_html)
-        )
+        return (main_fig, df, colored_text_html, alt_viz_html, drops_fig, anomaly_fig, total_pages, page)
 
     except Exception as e:
         logger.error("Visualization failed: %s", str(e))
-        return (
-            f"Error: {str(e)}", None, None, None, "Placeholder for K-Means Clustering", "Placeholder for Probability Drops",
-            "Placeholder for N-Gram Analysis", "Placeholder for Markov Chain", "Placeholder for Anomaly Detection",
-            "Placeholder for Autocorrelation", "Placeholder for Smoothing (Moving Average)", "Placeholder for Uncertainty Propagation",
-            "Placeholder for Correlation Heatmap", "Placeholder for Token Type Correlation", "Placeholder for Embedding Similarity vs. Probability",
-            "Placeholder for Bayesian Inference (Entropy)", "Placeholder for Graph of Token Transitions", "Placeholder for t-SNE of Log Probabilities",
-            "Placeholder for Interactive Heatmap", "Placeholder for Probability Distribution"
-        )
+        return (gr.update(value=f"Error: {str(e)}"), None, "No finite log probabilities to display.", None, gr.update(value="No data for probability drops."), gr.update(value="No data for anomalies."), 1, 0)
 
-# Gradio interface with …
+# Gradio interface with interactive layout and pagination
 with gr.Blocks(title="Log Probability Visualizer") as app:
     gr.Markdown("# Log Probability Visualizer")
     gr.Markdown(
-        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter …
+        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter and pagination to navigate large inputs."
     )
 
     with gr.Row():
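
The colored-text block min-max normalizes each page's log probs into [0, 1] and linearly blends red (low confidence) to green (high confidence). A tiny worked example of that mapping; the logprob_to_rgb helper and sample values are illustrative, not from the commit:

    def logprob_to_rgb(lp, lo, hi):
        # Min-max normalize, with the same 0.5 fallback for a constant page
        t = 0.5 if hi == lo else (lp - lo) / (hi - lo)
        return f"rgb({int(255 * (1 - t))}, {int(255 * t)}, 0)"

    lps = [-4.0, -1.0, -0.05]
    lo, hi = min(lps), max(lps)
    print([logprob_to_rgb(lp, lo, hi) for lp in lps])
    # -> ['rgb(255, 0, 0)', 'rgb(61, 193, 0)', 'rgb(0, 255, 0)']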
@@ -725,61 +262,54 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
         )
         with gr.Column(scale=1):
             prob_filter = gr.Slider(minimum=-1e9, maximum=0, value=-1e9, label="Log Probability Filter (≥)")
+            page_size = gr.Number(value=50, label="Page Size", precision=0, minimum=10, maximum=1000)
+            page = gr.Number(value=0, label="Page Number", precision=0, minimum=0)
 
-    with gr. …
-    …
-        plot_output = gr.HTML(label="Log Probability Plot (Click for Tokens)", value="Placeholder for Log Probability Plot")
-        table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives", value=None)
-        with gr.Row():
-            text_output = gr.HTML(label="Colored Text (Confidence Visualization)", value="Placeholder for Colored Text (Confidence Visualization)")
-            alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities", value="Placeholder for Top 3 Token Log Probabilities")
-
-    with gr.Tab("Clustering & Patterns"):
-        with gr.Row():
-            cluster_output = gr.HTML(label="K-Means Clustering", value="Placeholder for K-Means Clustering")
-            drops_output = gr.HTML(label="Probability Drops", value="Placeholder for Probability Drops")
-        with gr.Row():
-            ngram_output = gr.HTML(label="N-Gram Analysis", value="Placeholder for N-Gram Analysis")
-            markov_output = gr.HTML(label="Markov Chain", value="Placeholder for Markov Chain")
-
-    with gr.Tab("Time Series & Anomalies"):
-        with gr.Row():
-            anomaly_output = gr.HTML(label="Anomaly Detection", value="Placeholder for Anomaly Detection")
-            autocorr_output = gr.HTML(label="Autocorrelation", value="Placeholder for Autocorrelation")
-        with gr.Row():
-            smoothing_output = gr.HTML(label="Smoothing (Moving Average)", value="Placeholder for Smoothing (Moving Average)")
-            uncertainty_output = gr.HTML(label="Uncertainty Propagation", value="Placeholder for Uncertainty Propagation")
-
-    with gr.Tab("Correlation & Types"):
-        with gr.Row():
-            corr_output = gr.HTML(label="Correlation Heatmap", value="Placeholder for Correlation Heatmap")
-            type_output = gr.HTML(label="Token Type Correlation", value="Placeholder for Token Type Correlation")
+    with gr.Row():
+        plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
+        drops_output = gr.Plot(label="Probability Drops (Click for Details)")
 
-        …
-            bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)", value="Placeholder for Bayesian Inference (Entropy)")
-        with gr.Row():
-            graph_output = gr.HTML(label="Graph of Token Transitions", value="Placeholder for Graph of Token Transitions")
-            tsne_output = gr.HTML(label="t-SNE of Log Probabilities", value="Placeholder for t-SNE of Log Probabilities")
+    with gr.Row():
+        anomaly_output = gr.Plot(label="Anomaly Detection (Click for Details)")
+        table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
 
-        …
-            dist_output = gr.HTML(label="Probability Distribution", value="Placeholder for Probability Distribution")
+    with gr.Row():
+        text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
+        alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
 
     btn = gr.Button("Visualize")
     btn.click(
         fn=visualize_logprobs,
-        inputs=[json_input, prob_filter],
-        outputs=[
-            …
+        inputs=[json_input, prob_filter, page_size, page],
+        outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, anomaly_output, gr.State(visible=False), gr.State(visible=False)],
+    )
+
+    # Pagination controls
+    with gr.Row():
+        prev_btn = gr.Button("Previous Page")
+        next_btn = gr.Button("Next Page")
+        total_pages_output = gr.Number(label="Total Pages", interactive=False, visible=False)
+        current_page_output = gr.Number(label="Current Page", interactive=False, visible=False)
+
+    def update_page(json_input, prob_filter, page_size, current_page, action):
+        if action == "prev" and current_page > 0:
+            current_page -= 1
+        elif action == "next":
+            total_pages = visualize_logprobs(json_input, prob_filter, page_size, 0)[6]  # Get total pages
+            if current_page < total_pages - 1:
+                current_page += 1
+        return gr.update(value=current_page), gr.update(value=total_pages)
+
+    prev_btn.click(
+        fn=lambda *args: update_page(*args, "prev"),
+        inputs=[json_input, prob_filter, page_size, page, gr.State()],
+        outputs=[page, total_pages_output]
+    )
+
+    next_btn.click(
+        fn=lambda *args: update_page(*args, "next"),
+        inputs=[json_input, prob_filter, page_size, page, gr.State()],
+        outputs=[page, total_pages_output]
     )
 
 app.launch()
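
The pagination handlers above re-run visualize_logprobs to learn the page count and step the page index with edge checks. A stripped-down sketch of the same prev/next wiring, assuming current Gradio Blocks semantics; this toy demo is not the app's code:

    import gradio as gr

    with gr.Blocks() as demo:
        page = gr.Number(value=0, label="Page Number", precision=0)
        total = gr.Number(value=3, label="Total Pages", interactive=False)
        prev_btn = gr.Button("Previous Page")
        next_btn = gr.Button("Next Page")

        def step(cur, tot, delta):
            # Clamp to [0, tot - 1] so clicks past either end are no-ops
            return min(max(int(cur) + delta, 0), int(tot) - 1)

        prev_btn.click(lambda c, t: step(c, t, -1), inputs=[page, total], outputs=page)
        next_btn.click(lambda c, t: step(c, t, +1), inputs=[page, total], outputs=page)

    demo.launch()

Clamping in one helper keeps both buttons safe no-ops at the boundaries, which is simpler than writing separate prev and next branches.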