import gradio as gr
from transformers import AutoTokenizer, T5Tokenizer
from concurrent.futures import ThreadPoolExecutor
# Fixed list of custom tokenizers (left)
TOKENIZER_CUSTOM = {
    "T5 Extended": "alakxender/dhivehi-T5-tokenizer-extended",
    "RoBERTa Extended": "alakxender/dhivehi-roberta-tokenizer-extended",
    "Google mT5": "google/mt5-base",
    "Google mT5 Extended": "alakxender/mt5-dhivehi-tokenizer-extended",
    "DeBERTa Extended": "alakxender/deberta-dhivehi-tokenizer-extended",
    "XLM-RoBERTa Extended": "alakxender/xlmr-dhivehi-tokenizer-extended",
    "Bert Extended": "alakxender/bert-dhivehi-tokenizer-extended",
    "Bert Extended Fast": "alakxender/bert-fast-dhivehi-tokenizer-extended",
}
# Suggested stock model paths for the right input
SUGGESTED_STOCK_PATHS = [
    "google/flan-t5-base",
    "t5-small",
    "t5-base",
    "t5-large",
    "google/mt5-base",
    "microsoft/trocr-base-handwritten",
    "microsoft/trocr-base-printed",
    "microsoft/deberta-v3-base",
    "xlm-roberta-base",
    "naver-clova-ix/donut-base",
    "bert-base-multilingual-cased",
]
# Cache for loaded tokenizers to avoid reloading
tokenizer_cache = {}

# Load tokenizer with fallback to slow T5
def load_tokenizer(tokenizer_path):
    if tokenizer_path in tokenizer_cache:
        return tokenizer_cache[tokenizer_path]
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        tokenizer_cache[tokenizer_path] = tokenizer
        return tokenizer
    except Exception:
        # AutoTokenizer can fail when fast-tokenizer files are missing;
        # fall back to the slow SentencePiece T5Tokenizer for T5/mT5 paths.
        # ("t5" is a substring of "mt5", so one check covers both.)
        if "t5" in tokenizer_path.lower():
            tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
            tokenizer_cache[tokenizer_path] = tokenizer
            return tokenizer
        raise
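# A minimal sketch of the caching behaviour (illustrative; assumes network
# access to the Hugging Face Hub on the first call):
#
#   tok = load_tokenizer("google/mt5-base")   # downloaded and cached
#   tok2 = load_tokenizer("google/mt5-base")  # served from tokenizer_cache
#   assert tok is tok2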
# Tokenize and decode with enhanced visualization
def tokenize_display(text, tokenizer_path):
    try:
        tokenizer = load_tokenizer(tokenizer_path)
        encoding = tokenizer(text, return_offsets_mapping=False, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(encoding.input_ids)
        ids = encoding.input_ids
        decoded = tokenizer.decode(ids, skip_special_tokens=False)
        return tokens, ids, decoded
    except Exception as e:
        return [f"[ERROR] {str(e)}"], [], "[Tokenizer Error]"
def create_token_visualization(tokens, ids):
    """Create a visual representation of tokens with colors and spacing"""
    if not tokens or not ids:
        return "❌ No tokens to display"
    # Create colored token blocks
    token_blocks = []
    colors = ["🟦", "🟩", "🟨", "🟪", "🟧", "🟫"]
    for i, (token, token_id) in enumerate(zip(tokens, ids)):
        color = colors[i % len(colors)]
        # Clean token display (remove special characters for better readability)
        clean_token = token.replace('▁', '_').replace('</s>', '[END]').replace('<s>', '[START]')
        token_blocks.append(f"{color} `{clean_token}` ({token_id})")
    return " ".join(token_blocks)
# Side-by-side comparison with progress updates (synchronous; Gradio drives
# the progress bar via the gr.Progress argument)
def compare_side_by_side_with_progress(dv_text, en_text, custom_label, stock_path, progress=gr.Progress()):
    def format_block(title, tokenizer_path):
        dv_tokens, dv_ids, dv_decoded = tokenize_display(dv_text, tokenizer_path)
        en_tokens, en_ids, en_decoded = tokenize_display(en_text, tokenizer_path)
        return f"""\
## 🔤 {title}

### 🈁 Dhivehi: `{dv_text}`

**🎯 Tokens:** {len(dv_tokens) if dv_ids else 'N/A'} tokens

{create_token_visualization(dv_tokens, dv_ids)}

**🔢 Token IDs:** `{dv_ids if dv_ids else '[ERROR]'}`

**🔄 Decoded:** `{dv_decoded}`

---

### 🇬🇧 English: `{en_text}`

**🎯 Tokens:** {len(en_tokens) if en_ids else 'N/A'} tokens

{create_token_visualization(en_tokens, en_ids)}

**🔢 Token IDs:** `{en_ids if en_ids else '[ERROR]'}`

**🔄 Decoded:** `{en_decoded}`

---
"""
    try:
        custom_path = TOKENIZER_CUSTOM[custom_label]
    except KeyError:
        return "[ERROR] Invalid custom tokenizer selected", ""
    # Show loading progress
    progress(0.1, desc="Loading custom tokenizer...")
    # Load custom tokenizer
    try:
        custom_result = format_block("Custom Tokenizer", custom_path)
        progress(0.5, desc="Custom tokenizer loaded. Loading stock tokenizer...")
    except Exception as e:
        custom_result = f"[ERROR] Failed to load custom tokenizer: {str(e)}"
        progress(0.5, desc="Custom tokenizer failed. Loading stock tokenizer...")
    # Load stock tokenizer
    try:
        stock_result = format_block("Stock Tokenizer", stock_path)
        progress(1.0, desc="Complete!")
    except Exception as e:
        stock_result = f"[ERROR] Failed to load stock tokenizer: {str(e)}"
        progress(1.0, desc="Complete with errors!")
    return custom_result, stock_result
# Alternative non-blocking comparison as a generator (yields a loading
# message first, then the results); not currently wired to the UI
def compare_tokenizers_async(dv_text, en_text, custom_label, stock_path):
    # Return immediate loading message
    loading_msg = """
## ⏳ Loading Tokenizer...

🚀 **Status:** Downloading and initializing tokenizer...

*This may take a moment for first-time downloads*
"""
    # Use ThreadPoolExecutor for non-blocking execution
    with ThreadPoolExecutor(max_workers=2) as executor:
        future = executor.submit(compare_side_by_side_with_progress, dv_text, en_text, custom_label, stock_path)
        # Return loading state first
        yield loading_msg, loading_msg
        # Then return actual results
        try:
            custom_result, stock_result = future.result(timeout=120)  # 2 minute timeout
            yield custom_result, stock_result
        except Exception as e:
            error_msg = f"## ❌ Error\n\n**Failed to load tokenizers:** {str(e)}"
            yield error_msg, error_msg
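# Sketch of consuming the generator (the UI calls
# compare_side_by_side_with_progress directly instead):
#
#   gen = compare_tokenizers_async("ބަސް", "bus", "T5 Extended", "t5-small")
#   next(gen)  # -> (loading_msg, loading_msg), shown immediately
#   next(gen)  # -> (custom_markdown, stock_markdown) once loading finishes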
# Gradio UI with better UX
with gr.Blocks(title="Dhivehi Tokenizer Comparison Tool", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🧠 Dhivehi Tokenizer Comparison")
    gr.Markdown("Compare how different tokenizers process Dhivehi and English input text.")
    with gr.Row():
        dhivehi_text = gr.Textbox(
            label="Dhivehi Text",
            lines=2,
            value="އީދުގެ ހަރަކާތްތައް ފެށުމަށް މިރޭ ހުޅުމާލޭގައި އީދު މަޅި ރޯކުރަނީ",
            rtl=True,
            placeholder="Enter Dhivehi text here..."
        )
        english_text = gr.Textbox(
            label="English Text",
            lines=2,
            value="The quick brown fox jumps over the lazy dog",
            placeholder="Enter English text here..."
        )
    with gr.Row():
        tokenizer_a = gr.Dropdown(
            label="Select Custom Tokenizer",
            choices=list(TOKENIZER_CUSTOM.keys()),
            value="T5 Extended",
            info="Pre-trained Dhivehi tokenizers"
        )
        tokenizer_b = gr.Dropdown(
            label="Enter or Select Stock Tokenizer Path",
            choices=SUGGESTED_STOCK_PATHS,
            value="google/flan-t5-base",
            allow_custom_value=True,
            info="Standard HuggingFace tokenizers (or paste a path)"
        )
    compare_button = gr.Button("🔄 Compare Tokenizers", variant="primary", size="lg")
    with gr.Row():
        output_custom = gr.Markdown(label="Custom Tokenizer Output", height=400)
        output_stock = gr.Markdown(label="Stock Tokenizer Output", height=400)
    # Run the progress-enabled comparison when the button is clicked
    compare_button.click(
        compare_side_by_side_with_progress,
        inputs=[dhivehi_text, english_text, tokenizer_a, tokenizer_b],
        outputs=[output_custom, output_stock],
        show_progress=True
    )
if __name__ == "__main__":
    demo.launch()