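"""Implementation of the `convert` command for the datasets-cli.

Converts TensorFlow Datasets (tfds) dataset scripts into HuggingFace Datasets
scripts by applying the regex substitutions in TO_CONVERT, and wraps any line
mentioning a term from TO_HIGHLIGHT in conflict-style markers for manual review.
"""
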
import os
import re
import shutil
from argparse import ArgumentParser, Namespace

from datasets.commands import BaseDatasetsCLICommand
from datasets.utils.logging import get_logger


HIGHLIGHT_MESSAGE_PRE = """<<<<<<< This should probably be modified because it mentions: """

HIGHLIGHT_MESSAGE_POST = """=======
>>>>>>>
"""

TO_HIGHLIGHT = [
    "TextEncoderConfig",
    "ByteTextEncoder",
    "SubwordTextEncoder",
    "encoder_config",
    "maybe_build_from_corpus",
    "manual_dir",
]
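
# When a converted line mentions one of the terms above, it is wrapped in the
# git-style markers defined above so it is easy to find in the output, e.g.
# (output sketch, the matched line is illustrative):
#
#   <<<<<<< This should probably be modified because it mentions: ['manual_dir']
#       dl_manager.manual_dir
#   =======
#   >>>>>>>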

TO_CONVERT = [
    # (pattern, replacement)
    # Order is important here for some replacements
    (r"tfds\.core", r"datasets"),
    (r"tf\.io\.gfile\.GFile", r"open"),
    (r"tf\.([\w\d]+)", r"datasets.Value('\1')"),
    (r"tfds\.features\.Text\(\)", r"datasets.Value('string')"),
    (r"tfds\.features\.Text\(", r"datasets.Value('string'),"),
    (r"features\s*=\s*tfds.features.FeaturesDict\(", r"features=datasets.Features("),
    (r"tfds\.features\.FeaturesDict\(", r"dict("),
    (r"The TensorFlow Datasets Authors", r"The TensorFlow Datasets Authors and the HuggingFace Datasets Authors"),
    (r"tfds\.", r"datasets."),
    (r"dl_manager\.manual_dir", r"self.config.data_dir"),
    (r"self\.builder_config", r"self.config"),
]
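
# Illustrative sketch (not part of the command itself) of how TO_CONVERT
# rewrites a typical TFDS line:
#
#     line = 'features=tfds.features.FeaturesDict({"text": tfds.features.Text()})'
#     for pattern, replacement in TO_CONVERT:
#         line = re.sub(pattern, replacement, line)
#     # -> features=datasets.Features({"text": datasets.Value('string')})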


def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.

    Returns: ConvertCommand
    """
    return ConvertCommand(args.tfds_path, args.datasets_directory)


class ConvertCommand(BaseDatasetsCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the datasets-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        convert_parser = parser.add_parser(
            "convert",
            help="Convert a TensorFlow Datasets dataset to a HuggingFace Datasets dataset.",
        )
        convert_parser.add_argument(
            "--tfds_path",
            type=str,
            required=True,
            help="Path to a TensorFlow Datasets folder to convert or a single tfds file to convert.",
        )
        convert_parser.add_argument(
            "--datasets_directory", type=str, required=True, help="Path to the HuggingFace Datasets folder."
        )
        convert_parser.set_defaults(func=convert_command_factory)
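        # Example invocation once registered (paths are illustrative):
        #
        #   datasets-cli convert --tfds_path ./tfds/my_dataset.py --datasets_directory ./my_datasets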

    def __init__(self, tfds_path: str, datasets_directory: str, *args):
        self._logger = get_logger("datasets-cli/converting")

        self._tfds_path = tfds_path
        self._datasets_directory = datasets_directory

    def run(self):
        if os.path.isdir(self._tfds_path):
            abs_tfds_path = os.path.abspath(self._tfds_path)
        elif os.path.isfile(self._tfds_path):
            abs_tfds_path = os.path.dirname(self._tfds_path)
        else:
            raise ValueError("--tfds_path is neither a directory nor a file. Please check path.")

        abs_datasets_path = os.path.abspath(self._datasets_directory)

        self._logger.info(f"Converting datasets from {abs_tfds_path} to {abs_datasets_path}")

        utils_files = []
        with_manual_update = []
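        # Maps each imported utility module name to the dataset folder that imports it,
        # e.g. {"wmt_utils": "<datasets_directory>/wmt14"} (key and path are illustrative)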
        imports_to_builder_map = {}

        if os.path.isdir(self._tfds_path):
            file_names = os.listdir(abs_tfds_path)
        else:
            file_names = [os.path.basename(self._tfds_path)]

        for f_name in file_names:
            self._logger.info(f"Looking at file {f_name}")
            input_file = os.path.join(abs_tfds_path, f_name)
            output_file = os.path.join(abs_datasets_path, f_name)

            if not os.path.isfile(input_file) or "__init__" in f_name or "_test" in f_name or not f_name.endswith(".py"):
                self._logger.info(f"Skipping file {f_name}")
                continue

            with open(input_file, encoding="utf-8") as f:
                lines = f.readlines()

            out_lines = []
            is_builder = False
            needs_manual_update = False
            tfds_imports = []
            for line in lines:
                out_line = line

                # Convert imports
                if "import tensorflow.compat.v2 as tf" in out_line:
                    continue
                elif "@tfds.core" in out_line:
                    continue
                elif "builder=self" in out_line:
                    continue
                elif "import tensorflow_datasets.public_api as tfds" in out_line:
                    out_line = "import datasets\n"
                elif "import tensorflow" in out_line:
                    # order is important here
                    out_line = ""
                    continue
                elif "from absl import logging" in out_line:
                    out_line = "from datasets import logging\n"
                elif "getLogger" in out_line:
                    out_line = out_line.replace("getLogger", "get_logger")
                elif any(expression in out_line for expression in TO_HIGHLIGHT):
                    needs_manual_update = True
                    highlighted_terms = [e for e in TO_HIGHLIGHT if e in out_line]
                    out_lines.append(HIGHLIGHT_MESSAGE_PRE + str(highlighted_terms) + "\n")
                    out_lines.append(out_line)
                    out_lines.append(HIGHLIGHT_MESSAGE_POST)
                    continue
                else:
                    for pattern, replacement in TO_CONVERT:
                        out_line = re.sub(pattern, replacement, out_line)

                # Take care of saving utilities (to later move them together with the main script)
                if "tensorflow_datasets" in out_line:
                    match = re.match(r"from\stensorflow_datasets.*import\s([^\.\r\n]+)", out_line)
                    if match is not None:
                        tfds_imports.extend(imp.strip() for imp in match.group(1).split(","))
                        out_line = "from . import " + match.group(1) + "\n"

                # Check that we have not forgotten anything
                if "tf." in out_line or "tfds." in out_line or "tensorflow_datasets" in out_line:
                    raise ValueError(f"Error converting {out_line.strip()}")

                if "GeneratorBasedBuilder" in out_line:
                    is_builder = True
                out_lines.append(out_line)

            if is_builder or "wmt" in f_name:
                # We create a new directory for each dataset
                dir_name = f_name.replace(".py", "")
                output_dir = os.path.join(abs_datasets_path, dir_name)
                output_file = os.path.join(output_dir, f_name)
                os.makedirs(output_dir, exist_ok=True)
                self._logger.info(f"Adding directory {output_dir}")
                imports_to_builder_map.update(dict.fromkeys(tfds_imports, output_dir))
            else:
                # Utilities will be moved at the end
                utils_files.append(output_file)

            if needs_manual_update:
                with_manual_update.append(output_file)

            with open(output_file, "w", encoding="utf-8") as f:
                f.writelines(out_lines)
            self._logger.info(f"Converted in {output_file}")

        for utils_file in utils_files:
            try:
                f_name = os.path.basename(utils_file)
                dest_folder = imports_to_builder_map[f_name.replace(".py", "")]
                self._logger.info(f"Moving {dest_folder} to {utils_file}")
                shutil.copy(utils_file, dest_folder)
            except KeyError:
                self._logger.error(f"Cannot find destination folder for {utils_file}. Please copy manually.")

        if with_manual_update:
            for file_path in with_manual_update:
                self._logger.warning(
                    f"You need to manually update file {file_path} to resolve the sections wrapped in conflict "
                    "markers (e.g. usages of 'TextEncoderConfig')."
                )