import logging
from typing import Any, Dict

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)


class CustomDataset(TextCausalLanguageModelingCustomDataset):
    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        super().__init__(df=df, cfg=cfg, mode=mode)
        check_for_non_int_answers(cfg, df)
        self.answers_int = df[cfg.dataset.answer_column].astype(int).values.tolist()

        if 1 < cfg.dataset.num_classes <= max(self.answers_int):
            raise LLMDataException(
                f"Number of classes ({cfg.dataset.num_classes}) must be larger "
                f"than the max label ({max(self.answers_int)}). "
                "Please increase the setting accordingly."
            )
        elif cfg.dataset.num_classes == 1 and max(self.answers_int) > 1:
            raise LLMDataException(
                "For binary classification, max label should be 1 but is "
                f"{max(self.answers_int)}."
            )
        if min(self.answers_int) < 0:
            raise LLMDataException(
                "Labels should be non-negative but min label is "
                f"{min(self.answers_int)}."
            )
        if (
            min(self.answers_int) != 0
            or max(self.answers_int) != len(set(self.answers_int)) - 1
        ):
            logger.warning(
                "Labels should start at 0 and be contiguous (no gaps) but are "
                f"{sorted(set(self.answers_int))}."
            )

        if cfg.dataset.parent_id_column != "None":
            raise LLMDataException(
                "Parent ID column is not supported for classification datasets."
            )

    def __getitem__(self, idx: int) -> Dict:
        sample = super().__getitem__(idx)
        sample["class_label"] = self.answers_int[idx]
        return sample

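    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        # Cast logits to float32 (e.g. from half precision) before converting them.
        output["logits"] = output["logits"].float()
        preds = output["logits"]
        if cfg.dataset.num_classes == 1:
            # Binary classification: one logit per sample; a logit above 0.0
            # (sigmoid probability above 0.5) is predicted as class 1.
            preds = np.array((preds > 0.0)).astype(int).astype(str).reshape(-1)
        else:
            # Multiclass classification: predict the class with the largest logit.
            preds = (
                np.array(torch.argmax(preds, dim=1))  # type: ignore[arg-type]
                .astype(str)
                .reshape(-1)
            )
        # Predicted class indices are stored as strings under "predicted_text".
        output["predicted_text"] = preds
        return super().postprocess_output(cfg, df, output)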

    def clean_output(self, output, cfg):
        # No-op: predictions are class indices, so the parent's text cleanup
        # is not applied.
        return output

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        # TODO: Dataset import in UI is currently using text_causal_language_modeling_ds
        check_for_non_int_answers(cfg, df)
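

# Illustrative sketch (hypothetical helper, not part of LLM Studio and not used
# by CustomDataset): the prediction rule from postprocess_output restated with
# plain NumPy. It assumes logits of shape (batch, 1) when num_classes == 1 and
# (batch, num_classes) otherwise, already cast to float (a float32 CPU tensor
# or any array-like).
# Example: _sketch_logits_to_labels([[0.3], [-1.2]], num_classes=1) returns
# ["1", "0"].
def _sketch_logits_to_labels(logits, num_classes: int) -> list:
    import numpy as np  # local import keeps the sketch self-contained

    scores = np.asarray(logits, dtype=np.float32)
    if num_classes == 1:
        # Binary: a logit above 0.0 (sigmoid probability above 0.5) -> class 1.
        labels = (scores > 0.0).astype(int).reshape(-1)
    else:
        # Multiclass: index of the largest logit in each row.
        labels = scores.argmax(axis=1).reshape(-1)
    # Return string labels, matching the type stored in "predicted_text".
    return [str(x) for x in labels.tolist()]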


def check_for_non_int_answers(cfg, df):
    """Raise an LLMDataException if any answer cannot be cast with int()."""
    answers_non_int = [
        x for x in df[cfg.dataset.answer_column].values if not is_castable_to_int(x)
    ]
    if len(answers_non_int) > 0:
        raise LLMDataException(
            f"Column {cfg.dataset.answer_column} contains non-integer values. "
            f"Sample values: {answers_non_int[:5]}."
        )


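def is_castable_to_int(s):
    """Return True if int(s) succeeds, e.g. for 3, "3" or 3.0."""
    try:
        int(s)
        return True
    except (ValueError, TypeError, OverflowError):
        # ValueError: e.g. "abc" or NaN; TypeError: e.g. None;
        # OverflowError: infinite floats.
        return False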
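

# Illustrative sketch (hypothetical helper, not part of LLM Studio and not used
# by CustomDataset): the label checks from CustomDataset.__init__ restated as a
# pure function over a non-empty list of integer labels. Unlike __init__, it
# collects every problem instead of raising on the first one; note that the
# contiguity check is only a warning in __init__.
# Example: _sketch_label_problems([0, 2], num_classes=3) reports one gap.
def _sketch_label_problems(labels: list, num_classes: int) -> list:
    problems = []
    lowest, highest = min(labels), max(labels)
    if 1 < num_classes <= highest:
        # Multiclass: every label must be smaller than num_classes.
        problems.append(f"num_classes={num_classes} must exceed max label {highest}")
    elif num_classes == 1 and highest > 1:
        # Binary (a single logit): labels must be 0 or 1.
        problems.append(f"binary labels must be 0 or 1, max label is {highest}")
    if lowest < 0:
        problems.append(f"labels must be non-negative, min label is {lowest}")
    if lowest != 0 or highest != len(set(labels)) - 1:
        # With integer labels, this catches a non-zero start or gaps like [0, 2].
        problems.append(f"labels not contiguous from 0: {sorted(set(labels))}")
    return problems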