File size: 7,878 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import os
import re
import shutil
from argparse import ArgumentParser, Namespace
from datasets.commands import BaseDatasetsCLICommand
from datasets.utils.logging import get_logger
# Conflict-style marker emitted just before a source line that needs manual
# review; the list of matched TO_HIGHLIGHT names is appended to it.
HIGHLIGHT_MESSAGE_PRE = """<<<<<<< This should probably be modified because it mentions: """

# Conflict-style marker emitted right after the flagged source line.
HIGHLIGHT_MESSAGE_POST = """=======
>>>>>>>
"""

# Substrings that cannot be converted automatically: any line containing one of
# them is copied verbatim and wrapped in the highlight markers above.
TO_HIGHLIGHT = [
    "TextEncoderConfig",
    "ByteTextEncoder",
    "SubwordTextEncoder",
    "encoder_config",
    "maybe_build_from_corpus",
    "manual_dir",
]

# Regex substitutions applied, in this order, to every line that is not handled
# by a dedicated rule.
TO_CONVERT = [
    # (pattern, replacement)
    # Order is important here for some replacements
    (r"tfds\.core", r"datasets"),
    (r"tf\.io\.gfile\.GFile", r"open"),
    (r"tf\.([\w\d]+)", r"datasets.Value('\1')"),
    (r"tfds\.features\.Text\(\)", r"datasets.Value('string')"),
    (r"tfds\.features\.Text\(", r"datasets.Value('string'),"),
    # Dots are escaped so that only the literal dotted name matches, like the
    # other patterns in this table.
    (r"features\s*=\s*tfds\.features\.FeaturesDict\(", r"features=datasets.Features("),
    (r"tfds\.features\.FeaturesDict\(", r"dict("),
    (r"The TensorFlow Datasets Authors", r"The TensorFlow Datasets Authors and the HuggingFace Datasets Authors"),
    (r"tfds\.", r"datasets."),
    (r"dl_manager\.manual_dir", r"self.config.data_dir"),
    (r"self\.builder_config", r"self.config"),
]
def convert_command_factory(args: Namespace):
    """
    Build the command object for the `convert` sub-command from parsed CLI arguments.

    Args:
        args: Parsed arguments; `tfds_path` and `datasets_directory` are read from it.

    Returns: ConvertCommand
    """
    source_path = args.tfds_path
    target_directory = args.datasets_directory
    return ConvertCommand(source_path, target_directory)
class ConvertCommand(BaseDatasetsCLICommand):
    """CLI command that rewrites TensorFlow Datasets scripts as HuggingFace Datasets scripts."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the datasets-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser(
            "convert",
            help="Convert a TensorFlow Datasets dataset to a HuggingFace Datasets dataset.",
        )
        train_parser.add_argument(
            "--tfds_path",
            type=str,
            required=True,
            help="Path to a TensorFlow Datasets folder to convert or a single tfds file to convert.",
        )
        train_parser.add_argument(
            "--datasets_directory", type=str, required=True, help="Path to the HuggingFace Datasets folder."
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(self, tfds_path: str, datasets_directory: str, *args):
        """
        Store the conversion source and destination.

        Args:
            tfds_path: Folder of TensorFlow Datasets scripts, or a single script file.
            datasets_directory: Folder in which converted scripts are written.
            *args: Extra positional arguments; accepted and ignored.
        """
        self._logger = get_logger("datasets-cli/converting")
        self._tfds_path = tfds_path
        self._datasets_directory = datasets_directory

    def run(self):
        """
        Convert every eligible `.py` file found at `tfds_path`.

        Files whose converted content mentions `GeneratorBasedBuilder` (or whose
        name contains "wmt") are written to a directory of their own under the
        datasets directory. Other files are treated as utilities: they are
        written to the datasets directory and then copied next to the builder
        that imported them.

        Raises:
            ValueError: If `tfds_path` is neither a directory nor a file, or if a
                converted line still references `tf.`, `tfds.` or
                `tensorflow_datasets`.
        """
        if os.path.isdir(self._tfds_path):
            abs_tfds_path = os.path.abspath(self._tfds_path)
            is_directory = True
        elif os.path.isfile(self._tfds_path):
            # Resolve to an absolute path first so that a bare file name does
            # not yield an empty (relative) directory.
            abs_tfds_path = os.path.dirname(os.path.abspath(self._tfds_path))
            is_directory = False
        else:
            raise ValueError("--tfds_path is neither a directory nor a file. Please check path.")

        abs_datasets_path = os.path.abspath(self._datasets_directory)

        self._logger.info(f"Converting datasets from {abs_tfds_path} to {abs_datasets_path}")

        utils_files = []
        with_manual_update = []
        # Maps each name imported from tensorflow_datasets by a builder script
        # to the output directory of that builder.
        imports_to_builder_map = {}

        if is_directory:
            file_names = os.listdir(abs_tfds_path)
        else:
            file_names = [os.path.basename(self._tfds_path)]

        for f_name in file_names:
            self._logger.info(f"Looking at file {f_name}")
            input_file = os.path.join(abs_tfds_path, f_name)
            output_file = os.path.join(abs_datasets_path, f_name)

            if not os.path.isfile(input_file) or "__init__" in f_name or "_test" in f_name or ".py" not in f_name:
                self._logger.info("Skipping file")
                continue

            with open(input_file, encoding="utf-8") as f:
                lines = f.readlines()

            out_lines, is_builder, needs_manual_update, tfds_imports = self._convert_lines(lines)

            if is_builder or "wmt" in f_name:
                # We create a new directory for each dataset
                dir_name = f_name.replace(".py", "")
                output_dir = os.path.join(abs_datasets_path, dir_name)
                output_file = os.path.join(output_dir, f_name)
                os.makedirs(output_dir, exist_ok=True)
                self._logger.info(f"Adding directory {output_dir}")
                imports_to_builder_map.update(dict.fromkeys(tfds_imports, output_dir))
            else:
                # Utilities will be copied next to their builder at the end
                utils_files.append(output_file)

            if needs_manual_update:
                with_manual_update.append(output_file)

            with open(output_file, "w", encoding="utf-8") as f:
                f.writelines(out_lines)
            self._logger.info(f"Converted in {output_file}")

        for utils_file in utils_files:
            module_name = os.path.basename(utils_file).replace(".py", "")
            dest_folder = imports_to_builder_map.get(module_name)
            if dest_folder is None:
                self._logger.error(f"Cannot find destination folder for {utils_file}. Please copy manually.")
                continue
            self._logger.info(f"Copying {utils_file} to {dest_folder}")
            shutil.copy(utils_file, dest_folder)

        for file_path in with_manual_update:
            self._logger.warning(
                f"You need to manually update file {file_path} to remove configurations using 'TextEncoderConfig'."
            )

    def _convert_lines(self, lines):
        """
        Rewrite the lines of one TensorFlow Datasets script.

        Args:
            lines: Lines of the input file, each including its line ending.

        Returns:
            A tuple `(out_lines, is_builder, needs_manual_update, tfds_imports)`:
            the converted lines, whether `GeneratorBasedBuilder` was seen,
            whether any line was flagged with the highlight markers, and the
            names imported through `from tensorflow_datasets... import ...`.

        Raises:
            ValueError: If a converted line still references `tf.`, `tfds.` or
                `tensorflow_datasets`.
        """
        out_lines = []
        tfds_imports = []
        is_builder = False
        needs_manual_update = False

        for line in lines:
            out_line = line

            # Convert imports
            if "import tensorflow.compat.v2 as tf" in out_line:
                continue
            elif "@tfds.core" in out_line:
                continue
            elif "builder=self" in out_line:
                continue
            elif "import tensorflow_datasets.public_api as tfds" in out_line:
                out_line = "import datasets\n"
            elif "import tensorflow" in out_line:
                # order is important here: this must stay after the more
                # specific public_api check above
                continue
            elif "from absl import logging" in out_line:
                out_line = "from datasets import logging\n"
            elif "getLogger" in out_line:
                out_line = out_line.replace("getLogger", "get_logger")
            elif any(expression in out_line for expression in TO_HIGHLIGHT):
                # Cannot be converted automatically: keep the line as is and
                # surround it with markers naming what was found.
                needs_manual_update = True
                to_remove = [expression for expression in TO_HIGHLIGHT if expression in out_line]
                out_lines.append(HIGHLIGHT_MESSAGE_PRE + str(to_remove) + "\n")
                out_lines.append(out_line)
                out_lines.append(HIGHLIGHT_MESSAGE_POST)
                continue
            else:
                for pattern, replacement in TO_CONVERT:
                    out_line = re.sub(pattern, replacement, out_line)

            # Take care of saving utilities (to later move them together with main script)
            if "tensorflow_datasets" in out_line:
                match = re.match(r"from\stensorflow_datasets.*import\s([^\.\r\n]+)", out_line)
                # When the line is not a recognised import, it is left alone
                # here and rejected by the leftover check just below.
                if match is not None:
                    imported = match.group(1)
                    tfds_imports.extend(imp.strip() for imp in imported.split(","))
                    # The captured group excludes the line ending, so restore
                    # it; writelines() does not add one.
                    out_line = "from . import " + imported + "\n"

            # Check we have not forgotten anything
            if "tf." in out_line or "tfds." in out_line or "tensorflow_datasets" in out_line:
                raise ValueError(f"Error converting {out_line.strip()}")

            if "GeneratorBasedBuilder" in out_line:
                is_builder = True
            out_lines.append(out_line)

        return out_lines, is_builder, needs_manual_update, tfds_imports
|