import argilla as rg
from datasets import load_dataset

from src.dataset import (
    load_split,
    is_label,
    is_rating,
    is_int,
    is_float,
    get_feature_values,
    get_feature_labels,
    load_repo_id,
)


def define_dataset_setting(
    dataset_name, field_columns, question_columns, metadata_columns, argilla_space_url
):
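    """Build Argilla settings from the configured columns and create the dataset.

    Text fields, label/rating/text questions, and integer/float/term metadata
    properties are derived from the source dataset's features. Returns the
    serialized settings and a source-column -> Argilla-attribute mapping that
    ``add_records`` later uses to route values.
    """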
    client = rg.Argilla(api_url=argilla_space_url, api_key="owner.apikey")
    split = load_split()

    fields, questions, metadata, vectors = [], [], [], []
    mapping = {}

    # Add field columns
    for column_name in field_columns:
        field_column_name = f"{column_name}_field"
        fields.append(rg.TextField(name=field_column_name))
        mapping[column_name] = field_column_name

    # Add question columns
    for question_type, question_column_name, column_name in question_columns:
        if question_type == "Label":
            values = get_feature_values(split, column_name)
            titles = get_feature_labels(split, column_name)
            labels = {str(value): title for value, title in zip(values, titles)}
            questions.append(rg.LabelQuestion(name=question_column_name, labels=labels))
        elif question_type == "Rating":
            values = get_feature_values(split, column_name)
            questions.append(
                rg.RatingQuestion(name=question_column_name, values=values)
            )
        else:
            questions.append(rg.TextQuestion(name=question_column_name))

        # If the column already feeds a field or an earlier question, register
        # this entry under a "__"-suffixed alias so the earlier mapping survives.
        if column_name in mapping:
            column_name = f"{column_name}__"
        mapping[column_name] = question_column_name

    # Add metadata columns
    if not metadata_columns:
        metadata_columns = []

    for metadata_type, metadata_name, column_name in metadata_columns:
        if metadata_type == "Integer":
            metadata.append(rg.IntegerMetadataProperty(name=metadata_name))
        elif metadata_type == "Float":
            metadata.append(rg.FloatMetadataProperty(name=metadata_name))
        elif metadata_type == "Term":
            values = list(map(str, get_feature_values(split, column_name)))
            metadata.append(
                rg.TermsMetadataProperty(name=metadata_name, options=values)
            )
        # Same aliasing as above: avoid clobbering an existing mapping entry.
        if column_name in mapping:
            column_name = f"{column_name}__"
        mapping[column_name] = metadata_name

    settings = rg.Settings(fields=fields, questions=questions, metadata=metadata)

    dataset = rg.Dataset(name=dataset_name, settings=settings, client=client)

    if not dataset.exists():
        dataset.create()

    return str(settings.serialize()), mapping


def add_records(argilla_dataset_name, mapping, n_records, argilla_space_url):
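    """Log the first ``n_records`` rows of the configured split to the Argilla dataset.

    ``mapping`` routes each source column to the field, question, or metadata
    property it should populate.
    """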
    client = rg.Argilla(api_url=argilla_space_url, api_key="owner.apikey")
    split = load_split()
    df = load_dataset(load_repo_id())[split].take(n_records).to_pandas()
    dataset = client.datasets(argilla_dataset_name)
    questions = dataset.settings.questions
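    # Label questions were registered with string keys, so cast label-encoded
    # columns to str before logging to keep the values consistent.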
    for question in questions:
        if question.name in mapping.values():
            column_name = [k for k, v in mapping.items() if v == question.name][0]
            column_name = column_name.replace("__", "")
            if is_label(split, column_name):
                df[column_name] = df[column_name].apply(str)
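    # Materialize the "__"-suffixed alias columns so every key in the mapping
    # has a matching column in the dataframe.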
    for source, target in mapping.items():
        if source.endswith("__"):
            df[source] = df[source.replace("__", "")]
    records = df.to_dict(orient="records")
    dataset.records.log(records, mapping=mapping)
    return f"{len(df)} records added with mapping {mapping}"


def delete_dataset(argilla_dataset_name, argilla_space_url):
    """Delete the named Argilla dataset."""
    client = rg.Argilla(api_url=argilla_space_url, api_key="owner.apikey")
    dataset = client.datasets(argilla_dataset_name)
    dataset.delete()
    return f"Dataset {argilla_dataset_name} deleted"