import argilla as rg
from datasets import load_dataset

from src.dataset import (
    load_split,
    is_label,
    is_rating,
    is_int,
    is_float,
    get_feature_values,
    get_feature_labels,
    load_repo_id,
)

# API key used for every Argilla client created in this module.
_ARGILLA_API_KEY = "owner.apikey"

# Suffix appended to a source column name when that column is already a key
# of the mapping (i.e. the same column feeds more than one Argilla target).
_ALIAS_SUFFIX = "__"


def _free_mapping_key(column_name, mapping):
    """Return ``column_name``, suffixed until it is not yet a key of ``mapping``."""
    key = column_name
    # Keep appending: a single suffix is not enough when one column is used
    # as field, question *and* metadata, because the second alias would
    # otherwise overwrite the first one.
    while key in mapping:
        key = f"{key}{_ALIAS_SUFFIX}"
    return key


def _resolve_source_column(key, columns):
    """Return the column in ``columns`` that the (possibly aliased) ``key`` refers to.

    Only trailing alias suffixes are stripped, and only until an existing
    column is found, so column names that legitimately contain ``__`` are
    left untouched. If nothing matches, the fully stripped name is returned.
    """
    name = key
    while name not in columns and name.endswith(_ALIAS_SUFFIX):
        name = name[: -len(_ALIAS_SUFFIX)]
    return name


def define_dataset_setting(
    dataset_name, field_columns, question_columns, metadata_columns, argilla_space_url
):
    """Build Argilla settings from column selections and create the dataset.

    Args:
        dataset_name: Name of the Argilla dataset to create.
        field_columns: Iterable of source column names; each becomes a
            ``TextField`` named ``"<column>_field"``.
        question_columns: Iterable of ``(question_type, question_name,
            column_name)`` triples. ``"Label"`` and ``"Rating"`` create the
            matching question types; anything else creates a ``TextQuestion``.
        metadata_columns: Iterable of ``(metadata_type, metadata_name,
            column_name)`` triples, or a falsy value for no metadata.
            Recognised types are ``"Integer"``, ``"Float"`` and ``"Term"``.
        argilla_space_url: URL passed to the Argilla client as ``api_url``.

    Returns:
        A tuple ``(serialized_settings, mapping)`` where
        ``serialized_settings`` is ``str(settings.serialize())`` and
        ``mapping`` maps source column keys to Argilla field / question /
        metadata names. A column used for more than one target appears once
        per target, with ``"__"`` appended as often as needed to keep the
        keys distinct.
    """
    client = rg.Argilla(api_url=argilla_space_url, api_key=_ARGILLA_API_KEY)
    split = load_split()

    fields, questions, metadata = [], [], []
    mapping = {}

    # Add field columns
    for column_name in field_columns:
        field_column_name = f"{column_name}_field"
        fields.append(rg.TextField(name=field_column_name))
        mapping[column_name] = field_column_name

    # Add question columns
    for question_type, question_column_name, column_name in question_columns:
        if question_type == "Label":
            values = get_feature_values(split, column_name)
            titles = get_feature_labels(split, column_name)
            labels = {str(value): title for value, title in zip(values, titles)}
            questions.append(
                rg.LabelQuestion(name=question_column_name, labels=labels)
            )
        elif question_type == "Rating":
            values = get_feature_values(split, column_name)
            questions.append(
                rg.RatingQuestion(name=question_column_name, values=values)
            )
        else:
            questions.append(rg.TextQuestion(name=question_column_name))
        mapping[_free_mapping_key(column_name, mapping)] = question_column_name

    # Add metadata columns
    for metadata_type, metadata_name, column_name in metadata_columns or []:
        if metadata_type == "Integer":
            metadata.append(rg.IntegerMetadataProperty(name=metadata_name))
        elif metadata_type == "Float":
            metadata.append(rg.FloatMetadataProperty(name=metadata_name))
        elif metadata_type == "Term":
            values = list(map(str, get_feature_values(split, column_name)))
            metadata.append(
                rg.TermsMetadataProperty(name=metadata_name, options=values)
            )
        # NOTE: an unrecognised metadata_type adds no property but is still
        # recorded in the mapping, exactly as before.
        mapping[_free_mapping_key(column_name, mapping)] = metadata_name

    settings = rg.Settings(fields=fields, questions=questions, metadata=metadata)
    dataset = rg.Dataset(name=dataset_name, settings=settings, client=client)
    if not dataset.exists():
        dataset.create()

    return str(settings.serialize()), mapping


def add_records(argilla_dataset_name, mapping, n_records, argilla_space_url):
    """Log ``n_records`` rows of the source dataset into an Argilla dataset.

    Args:
        argilla_dataset_name: Name of the existing Argilla dataset.
        mapping: Source-column-key -> Argilla-name mapping, as returned by
            ``define_dataset_setting``. Keys may carry trailing ``"__"``
            alias suffixes.
        n_records: Number of rows taken from the split via ``.take()``.
        argilla_space_url: URL passed to the Argilla client as ``api_url``.

    Returns:
        A human-readable summary string with the number of rows logged and
        the mapping that was used.
    """
    client = rg.Argilla(api_url=argilla_space_url, api_key=_ARGILLA_API_KEY)
    split = load_split()
    df = load_dataset(load_repo_id())[split].take(n_records).to_pandas()
    dataset = client.datasets(argilla_dataset_name)

    # Label questions are declared with string labels, so the source values
    # of label-typed columns are converted to strings before logging.
    for question in dataset.settings.questions:
        source_key = next(
            (key for key, target in mapping.items() if target == question.name),
            None,
        )
        if source_key is None:
            continue
        column_name = _resolve_source_column(source_key, set(df.columns))
        if is_label(split, column_name):
            df[column_name] = df[column_name].apply(str)

    # Materialise every aliased key as a real column so the mapping keys all
    # exist in the records. Only the trailing alias suffix is stripped;
    # "__" inside a genuine column name is preserved.
    for source in mapping:
        base_column = _resolve_source_column(source, set(df.columns))
        if base_column != source:
            df[source] = df[base_column]

    records = df.to_dict(orient="records")
    dataset.records.log(records, mapping=mapping)
    return f"{len(df)} records added with mapping {mapping}"


def delete_dataset(argilla_dataset_name, argilla_space_url=None):
    """Delete an Argilla dataset by name.

    Args:
        argilla_dataset_name: Name of the dataset to delete.
        argilla_space_url: Optional URL passed to the Argilla client as
            ``api_url``. When omitted, the client is created without an
            explicit URL and falls back to its own default.

    Returns:
        A confirmation message naming the deleted dataset.
    """
    # The function previously used an undefined ``client`` name and raised
    # NameError on every call; it now creates its own client.
    if argilla_space_url is None:
        client = rg.Argilla(api_key=_ARGILLA_API_KEY)
    else:
        client = rg.Argilla(api_url=argilla_space_url, api_key=_ARGILLA_API_KEY)
    dataset = client.datasets(argilla_dataset_name)
    dataset.delete()
    return f"Dataset {argilla_dataset_name} deleted"