Spaces:
Running
Running
File size: 5,889 Bytes
2018b94 73d9a01 2018b94 9867f8a 73d9a01 f80827c 73d9a01 2018b94 9867f8a 73d9a01 2018b94 73d9a01 2018b94 73d9a01 2018b94 f80827c 73d9a01 f80827c 2018b94 73d9a01 2018b94 73d9a01 2018b94 f80827c 2018b94 f80827c 73d9a01 2018b94 73d9a01 2018b94 73d9a01 2018b94 73d9a01 2018b94 73d9a01 2018b94 73d9a01 2018b94 73d9a01 f80827c 2018b94 73d9a01 2018b94 73d9a01 2018b94 73d9a01 2018b94 73d9a01 f80827c 73d9a01 f80827c 73d9a01 2018b94 73d9a01 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import math
import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder
# Markdown banner rendered at the top of the app.
# NOTE(review): the badge links below have empty image/alt text — presumably
# lost during extraction; confirm against the original README badges.
BANNER = """
# 🌐 TransformerRanker ⚡️
Find the best language model for your downstream task.
Load a dataset, pick models from the 🤗 Hub, and rank them by **transferability**.
[](https://github.com/flairNLP/transformer-ranker)
[](https://opensource.org/licenses/MIT)
[](https://pypi.org/project/transformer-ranker/)
[](https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md)
Developed at [Humboldt University of Berlin](https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/).
"""
# Markdown footer with credits and a link for questions/issues.
FOOTER = """
**Note:** Quick CPU-only demo.
**Built by** [@lukasgarbas](https://huggingface.co/lukasgarbas) & [@plonerma](https://huggingface.co/plonerma)
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫
"""
# Custom CSS: cap the app width and center the banner contents.
CSS = """
.gradio-container {
max-width: 800px;
margin: auto;
}
.banner {
text-align: center;
}
.banner img {
display: inline-block;
}
"""
# Sentinel dropdown value meaning "not selected / not auto-detected".
UNSET = "-"
# Shared Hub client and dataset-preprocessing helper used by the handlers below.
hf_api = HfApi()
preprocessing = DatasetCleaner()
def validate_dataset(dataset_name):
    """Check that *dataset_name* exists on the HF Hub.

    Enables the load button on success; on failure resets its label and
    disables it.
    """
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        # Unknown dataset or malformed repo id — keep the button disabled.
        return gr.update(value="Load data", interactive=False)
    return gr.update(interactive=True)
def preprocess_dataset(dataset):
    """Auto-detect text/label columns and the task category of a loaded dataset.

    Merges all splits, runs the cleaner's column/task detection, and warns the
    user (via gr.Warning) for anything that could not be detected.

    Returns gr.update objects for the task-category, text-column,
    secondary-text-column, and label-column dropdowns, plus the total number
    of samples across all splits.
    """
    merged = concatenate_datasets(list(dataset.values()))
    num_samples = len(merged)

    try:
        text_column = preprocessing._find_column(merged, "text column")
    except ValueError:
        text_column = UNSET
        gr.Warning("Text column not auto-detected — select in settings.")

    try:
        label_column = preprocessing._find_column(merged, "label column")
    except ValueError:
        label_column = UNSET
        gr.Warning("Label column not auto-detected — select in settings.")

    # Task category can only be inferred once a label column is known.
    task_category = UNSET
    if label_column != UNSET:
        try:
            task_category = preprocessing._find_task_category(merged, label_column)
        except ValueError:
            gr.Warning(
                "Task category not auto-detected — framework supports classification, regression."
            )

    columns = merged.column_names
    task_update = gr.update(
        value=task_category, choices=[str(t) for t in TaskCategory], interactive=True
    )
    text_update = gr.update(value=text_column, choices=columns, interactive=True)
    # Secondary text column (e.g. sentence pairs) always starts unset.
    pair_update = gr.update(value=UNSET, choices=[UNSET, *columns], interactive=True)
    label_update = gr.update(value=label_column, choices=columns, interactive=True)
    return task_update, text_update, pair_update, label_update, num_samples
def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of the dataset to use, guarding against a zero
    (or negative) total sample count."""
    return num_samples_to_use / num_samples if num_samples > 0 else 0.0
def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable the ranking button only when a dataset is loaded and every
    required setting (text column, label column, task category) is resolved."""
    ready = bool(dataset) and UNSET not in (text_column, label_column, task_category)
    return gr.update(interactive=ready)
# apply monkey patch to enable callbacks
# Wrap Embedder.embed so that, when a tracker is attached, it learns the
# number of batches for the upcoming model (ceil of sentences / batch size)
# before embedding starts — this sets the progress denominator.
_old_embed = Embedder.embed
def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    if embedder.tracker is not None:
        embedder.tracker.update_num_batches(math.ceil(len(sentences) / batch_size))
    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)
Embedder.embed = _new_embed
# Wrap Embedder.embed_batch so the tracker is notified after each completed
# batch; the original return value is passed through unchanged.
_old_embed_batch = Embedder.embed_batch
def _new_embed_batch(embedder, *args, **kw):
    r = _old_embed_batch(embedder, *args, **kw)
    if embedder.tracker is not None:
        embedder.tracker.update_batch_complete()
    return r
Embedder.embed_batch = _new_embed_batch
# Wrap Embedder.__init__ to accept an optional keyword-only `tracker` and
# store it on the instance; the patched embed/embed_batch above check this
# attribute, so every Embedder instance must have it (None disables tracking).
_old_init = Embedder.__init__
def _new_init(embedder, *args, tracker=None, **kw):
    _old_init(embedder, *args, **kw)
    embedder.tracker = tracker
Embedder.__init__ = _new_init
class EmbeddingProgressTracker:
    """Context manager reporting per-model embedding progress to a progress bar.

    The monkey-patched Embedder callbacks call `update_num_batches()` when a
    model starts embedding and `update_batch_complete()` after each batch.
    Overall progress is `current_model / total` plus the finished-batch
    fraction of the model in flight, scaled by `1 / total`.
    """

    def __init__(self, *, progress, model_names):
        # Names of the models being ranked; defines the progress denominator.
        self.model_names = model_names
        # Callable progress bar (e.g. gr.Progress). May be None, in which
        # case __enter__ creates a fresh gr.Progress.
        self.progress_bar = progress
        # Per-run counters, reset again in __enter__.
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    @property
    def total(self):
        """Number of models being tracked."""
        return len(self.model_names)

    def __enter__(self):
        # Bug fix: the original unconditionally replaced self.progress_bar
        # here, silently discarding the `progress` passed to __init__.
        # Only create a new bar when none was supplied.
        if self.progress_bar is None:
            self.progress_bar = gr.Progress(track_tqdm=False)
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None
        return self

    def __exit__(self, typ, value, tb):
        # Finish the bar either way; return False so exceptions propagate.
        self.progress_bar(1.0, desc="Done" if typ is None else "Error")
        return False

    def update_num_batches(self, total):
        """Advance to the next model, expecting `total` batches for it."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch for the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the aggregate progress fraction and description to the bar."""
        i = self.current_model
        description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
        progress = i / self.total
        if self.batches_total is not None:
            progress += (self.batches_complete / self.batches_total) / self.total
        self.progress_bar(progress=progress, desc=description)
|