File size: 5,889 Bytes
2018b94
 
73d9a01
 
 
 
 
 
 
 
2018b94
9867f8a
 
 
 
 
 
 
 
 
 
 
73d9a01
 
 
f80827c
 
 
73d9a01
 
 
2018b94
 
 
 
9867f8a
 
 
 
 
 
73d9a01
 
2018b94
73d9a01
 
2018b94
73d9a01
 
2018b94
f80827c
73d9a01
f80827c
2018b94
73d9a01
2018b94
73d9a01
 
2018b94
f80827c
2018b94
f80827c
73d9a01
 
2018b94
73d9a01
2018b94
 
73d9a01
 
2018b94
73d9a01
2018b94
 
73d9a01
2018b94
 
73d9a01
2018b94
73d9a01
f80827c
 
 
2018b94
73d9a01
 
 
 
 
 
 
2018b94
73d9a01
 
2018b94
73d9a01
 
2018b94
73d9a01
f80827c
73d9a01
f80827c
73d9a01
 
 
 
 
 
 
 
2018b94
 
 
 
 
 
 
 
73d9a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import math

import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder

# Markdown rendered at the top of the app: title, tagline, badge links.
# (Trailing double-spaces are intentional markdown hard line breaks.)
BANNER = """
# 🌐 TransformerRanker ⚡️

Find the best language model for your downstream task.  
Load a dataset, pick models from the 🤗 Hub, and rank them by **transferability**.  

[![repository](https://img.shields.io/badge/Code%20Repo-black?style=flat&logo=github)](https://github.com/flairNLP/transformer-ranker)
[![license](https://img.shields.io/badge/License-MIT-brightgreen?style=flat)](https://opensource.org/licenses/MIT)
[![package](https://img.shields.io/badge/Package-orange?style=flat&logo=python)](https://pypi.org/project/transformer-ranker/)
[![tutorials](https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white)](https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md)

Developed at [Humboldt University of Berlin](https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/).
"""

# Markdown rendered at the bottom of the app: demo note, authors, issue link.
FOOTER = """
**Note:** Quick CPU-only demo.   
**Built by** [@lukasgarbas](https://huggingface.co/lukasgarbas) & [@plonerma](https://huggingface.co/plonerma)   
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫
"""

# Custom CSS: constrain the app width and center the banner content.
CSS = """
.gradio-container {
    max-width: 800px;
    margin: auto;
}
.banner {
    text-align: center;
}
.banner img {
    display: inline-block;
}
"""

# Placeholder shown in dropdowns while a column/task has not been chosen yet.
UNSET = "-"

# Shared Hub client (dataset existence checks) and dataset cleaner
# (column / task-category auto-detection) used by the handlers below.
hf_api = HfApi()
preprocessing = DatasetCleaner()


def validate_dataset(dataset_name):
    """Check that ``dataset_name`` exists on the Hugging Face Hub.

    Enables the load button when the lookup succeeds; on a failed or
    invalid lookup, resets the button label and disables it.
    """
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        return gr.update(value="Load data", interactive=False)
    return gr.update(interactive=True)


def preprocess_dataset(dataset):
    """Auto-detect text/label columns and the task category of a dataset.

    Merges all splits, then tries to detect the text column, the label
    column, and (given a label column) the task category, warning the
    user for anything that could not be detected.

    Returns dropdown updates for task category, text column, optional
    text-pair column, and label column, plus the total sample count.
    """
    merged = concatenate_datasets(list(dataset.values()))
    num_samples = len(merged)

    def detect(kind, warning):
        # Fall back to UNSET (with a user-visible warning) on detection failure.
        try:
            return preprocessing._find_column(merged, kind)
        except ValueError:
            gr.Warning(warning)
            return UNSET

    text_column = detect(
        "text column", "Text column not auto-detected — select in settings."
    )
    label_column = detect(
        "label column", "Label column not auto-detected — select in settings."
    )

    # Task category can only be inferred once a label column is known.
    task_category = UNSET
    if label_column != UNSET:
        try:
            task_category = preprocessing._find_task_category(merged, label_column)
        except ValueError:
            gr.Warning(
                "Task category not auto-detected — framework supports classification, regression."
            )

    columns = merged.column_names
    return (
        gr.update(
            value=task_category,
            choices=[str(t) for t in TaskCategory],
            interactive=True,
        ),
        gr.update(value=text_column, choices=columns, interactive=True),
        gr.update(value=UNSET, choices=[UNSET, *columns], interactive=True),
        gr.update(value=label_column, choices=columns, interactive=True),
        num_samples,
    )


def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of the dataset selected; 0.0 for an empty dataset."""
    return num_samples_to_use / num_samples if num_samples > 0 else 0.0


def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable the downstream button only once every required field is set."""
    ready = bool(dataset) and UNSET not in (text_column, label_column, task_category)
    return gr.update(interactive=ready)


# Monkey-patch Embedder so an optional progress tracker can observe
# embedding work (batch counts announced up front, batches as they finish).
_original_embed = Embedder.embed


def _tracked_embed(self, sentences, batch_size: int = 32, **kwargs):
    """Announce the upcoming batch count to the tracker, then delegate."""
    if self.tracker is not None:
        self.tracker.update_num_batches(math.ceil(len(sentences) / batch_size))
    return _original_embed(self, sentences, batch_size=batch_size, **kwargs)


Embedder.embed = _tracked_embed

_original_embed_batch = Embedder.embed_batch


def _tracked_embed_batch(self, *args, **kwargs):
    """Delegate, then notify the tracker that one batch completed."""
    result = _original_embed_batch(self, *args, **kwargs)
    if self.tracker is not None:
        self.tracker.update_batch_complete()
    return result


Embedder.embed_batch = _tracked_embed_batch

_original_init = Embedder.__init__


def _tracked_init(self, *args, tracker=None, **kwargs):
    """Accept an optional keyword-only tracker and store it on the instance."""
    _original_init(self, *args, **kwargs)
    self.tracker = tracker


Embedder.__init__ = _tracked_init


class EmbeddingProgressTracker:
    """Context manager driving a Gradio progress bar across several models.

    Overall progress is split evenly between models; within a model it
    advances by the fraction of embedding batches completed.
    """

    def __init__(self, *, progress, model_names):
        self.model_names = model_names
        self.progress_bar = progress

    @property
    def total(self):
        """Number of models being ranked."""
        return len(self.model_names)

    def __enter__(self):
        # NOTE(review): this replaces the injected `progress` object with a
        # fresh gr.Progress, so the constructor argument is unused once the
        # context is entered — confirm that is intended.
        self.progress_bar = gr.Progress(track_tqdm=False)
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        message = "Done" if exc_type is None else "Error"
        self.progress_bar(1.0, desc=message)
        return False  # never suppress the exception

    def update_num_batches(self, total):
        """Start tracking the next model, which has `total` batches pending."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch and refresh the bar."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push combined model-level + batch-level progress to the bar."""
        index = self.current_model
        label = f"Running {self.model_names[index]} ({index + 1} / {self.total})"

        fraction = index / self.total
        if self.batches_total is not None:
            fraction += (self.batches_complete / self.batches_total) / self.total

        self.progress_bar(progress=fraction, desc=label)