|
import os |
|
import re |
|
import shutil |
|
from argparse import ArgumentParser, Namespace |
|
|
|
from datasets.commands import BaseDatasetsCLICommand |
|
from datasets.utils.logging import get_logger |
|
|
|
|
|
# Conflict-style markers wrapped around every line that mentions an entry of TO_HIGHLIGHT.
# The converter writes HIGHLIGHT_MESSAGE_PRE followed by the list of matched expressions and a
# newline, then the untouched original line, then HIGHLIGHT_MESSAGE_POST.
HIGHLIGHT_MESSAGE_PRE = "<<<<<<< This should probably be modified because it mentions: "
HIGHLIGHT_MESSAGE_POST = "=======\n>>>>>>>\n"

# tfds expressions the converter does not rewrite automatically: lines containing any of these
# are copied verbatim (no TO_CONVERT rewrite) but highlighted for manual review.
TO_HIGHLIGHT = [
    "TextEncoderConfig",
    "ByteTextEncoder",
    "SubwordTextEncoder",
    "encoder_config",
    "maybe_build_from_corpus",
    "manual_dir",
]
|
|
|
# Ordered (pattern, replacement) pairs applied with re.sub, one after another, to every line the
# converter does not handle through a special case.
# Order is important here:
#   * "tfds.core" must be rewritten before the generic "tfds." rule (otherwise "tfds.core.X"
#     would become "datasets.core.X" instead of "datasets.X");
#   * "tf.io.gfile.GFile" must be rewritten before the generic "tf.<name>" rule (otherwise it
#     would become "datasets.Value('io').gfile.GFile");
#   * "tfds.features.Text()" must be tried before "tfds.features.Text(".
# Every literal dot is escaped so a pattern only matches the exact tfds/tf attribute path.
TO_CONVERT = [
    (r"tfds\.core", r"datasets"),
    (r"tf\.io\.gfile\.GFile", r"open"),
    # Any remaining tf.<name> (e.g. tf.int32) becomes datasets.Value('<name>').
    (r"tf\.(\w+)", r"datasets.Value('\1')"),
    (r"tfds\.features\.Text\(\)", r"datasets.Value('string')"),
    (r"tfds\.features\.Text\(", r"datasets.Value('string'),"),
    (r"features\s*=\s*tfds\.features\.FeaturesDict\(", r"features=datasets.Features("),
    (r"tfds\.features\.FeaturesDict\(", r"dict("),
    (r"The TensorFlow Datasets Authors", r"The TensorFlow Datasets Authors and the HuggingFace Datasets Authors"),
    (r"tfds\.", r"datasets."),
    (r"dl_manager\.manual_dir", r"self.config.data_dir"),
    (r"self\.builder_config", r"self.config"),
]
|
|
|
|
|
def convert_command_factory(args: Namespace):
    """
    Factory function used by datasets-cli to build a ConvertCommand from parsed arguments.

    Args:
        args: parsed command-line arguments; ``args.tfds_path`` and ``args.datasets_directory``
            are forwarded to ConvertCommand.

    Returns: ConvertCommand
    """
    return ConvertCommand(args.tfds_path, args.datasets_directory)
|
|
|
|
|
class ConvertCommand(BaseDatasetsCLICommand):
    """datasets-cli command rewriting TensorFlow Datasets scripts into HuggingFace Datasets scripts.

    The conversion is a line-by-line textual rewrite: a few import/decorator lines are special-cased,
    lines mentioning an entry of ``TO_HIGHLIGHT`` are copied unchanged between conflict-style markers
    for manual review, and every other line goes through the ``TO_CONVERT`` regex rewrites.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the datasets-cli

        Args:
            parser: object whose ``add_parser`` creates the ``convert`` sub-command parser
                (the datasets-cli sub-parsers object)
        """
        train_parser = parser.add_parser(
            "convert",
            help="Convert a TensorFlow Datasets dataset to a HuggingFace Datasets dataset.",
        )
        train_parser.add_argument(
            "--tfds_path",
            type=str,
            required=True,
            help="Path to a TensorFlow Datasets folder to convert or a single tfds file to convert.",
        )
        train_parser.add_argument(
            "--datasets_directory", type=str, required=True, help="Path to the HuggingFace Datasets folder."
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(self, tfds_path: str, datasets_directory: str, *args):
        """
        Args:
            tfds_path: a tfds folder, or a single tfds ``.py`` file, to convert.
            datasets_directory: folder the converted scripts are written to.
            *args: ignored (not used by this command).
        """
        self._logger = get_logger("datasets-cli/converting")

        self._tfds_path = tfds_path
        self._datasets_directory = datasets_directory

    @staticmethod
    def _is_convertible(input_file: str, f_name: str) -> bool:
        """Return True for regular ``.py`` files that are neither ``__init__`` nor test modules."""
        return (
            os.path.isfile(input_file)
            and f_name.endswith(".py")
            and "__init__" not in f_name
            and "_test" not in f_name
        )

    @staticmethod
    def _convert_lines(lines):
        """Rewrite the lines of a single tfds script.

        Args:
            lines: the script's lines, each including its trailing newline (as from ``readlines``).

        Returns:
            A tuple ``(out_lines, is_builder, needs_manual_update, tfds_imports)``:
            the converted lines; whether a line mentions ``GeneratorBasedBuilder``; whether any line
            was highlighted for manual review; and the names imported from ``tensorflow_datasets``
            (rewritten as relative imports, so those modules must be shipped next to the script).

        Raises:
            ValueError: if a converted line still mentions ``tf.``, ``tfds.`` or ``tensorflow_datasets``.
        """
        out_lines = []
        is_builder = False
        needs_manual_update = False
        tfds_imports = []
        for line in lines:
            out_line = line

            # Convert imports. Order matters: the specific tensorflow imports must be tested
            # before the generic "import tensorflow" catch-all.
            if "import tensorflow.compat.v2 as tf" in out_line:
                continue
            elif "@tfds.core" in out_line:
                continue
            elif "builder=self" in out_line:
                continue
            elif "import tensorflow_datasets.public_api as tfds" in out_line:
                out_line = "import datasets\n"
            elif "import tensorflow" in out_line:
                continue
            elif "from absl import logging" in out_line:
                out_line = "from datasets import logging\n"
            elif "getLogger" in out_line:
                out_line = out_line.replace("getLogger", "get_logger")
            elif any(expression in out_line for expression in TO_HIGHLIGHT):
                # Keep the line untouched but wrap it in markers listing what needs a manual fix.
                needs_manual_update = True
                to_remove = [expression for expression in TO_HIGHLIGHT if expression in out_line]
                out_lines.append(HIGHLIGHT_MESSAGE_PRE + str(to_remove) + "\n")
                out_lines.append(out_line)
                out_lines.append(HIGHLIGHT_MESSAGE_POST)
                continue
            else:
                for pattern, replacement in TO_CONVERT:
                    out_line = re.sub(pattern, replacement, out_line)

            # Take care of saving utilities (to later move them together with main script)
            if "tensorflow_datasets" in out_line:
                match = re.match(r"from\stensorflow_datasets.*import\s([^\.\r\n]+)", out_line)
                if match is not None:
                    imported = match.group(1).rstrip()
                    tfds_imports.extend(imp.strip() for imp in imported.split(","))
                    # The capture group stops before the line terminator: re-add it, otherwise the
                    # next line would be glued onto this import in the output file.
                    out_line = "from . import " + imported + "\n"
                # When there is no match the line still mentions tensorflow_datasets and is
                # rejected by the check below with a ValueError.

            # Check we have not forgotten anything
            if "tf." in out_line or "tfds." in out_line or "tensorflow_datasets" in out_line:
                raise ValueError(f"Error converting {out_line.strip()}")

            if "GeneratorBasedBuilder" in out_line:
                is_builder = True
            out_lines.append(out_line)

        return out_lines, is_builder, needs_manual_update, tfds_imports

    def run(self):
        """Convert every eligible tfds script and write the results under ``datasets_directory``.

        Builder scripts (and files whose name contains ``wmt``) get their own sub-folder; the other
        converted scripts are written directly in ``datasets_directory`` and then copied into the
        folder of the builder that imports them.

        Raises:
            ValueError: if ``tfds_path`` is neither a directory nor a file, or a line cannot be converted.
        """
        if os.path.isdir(self._tfds_path):
            abs_tfds_path = os.path.abspath(self._tfds_path)
            file_names = os.listdir(abs_tfds_path)
        elif os.path.isfile(self._tfds_path):
            # Resolve first so that a bare file name does not produce an empty directory.
            abs_tfds_path = os.path.dirname(os.path.abspath(self._tfds_path))
            file_names = [os.path.basename(self._tfds_path)]
        else:
            raise ValueError("--tfds_path is neither a directory nor a file. Please check path.")

        abs_datasets_path = os.path.abspath(self._datasets_directory)
        # Utility scripts are written straight into this folder, so it has to exist.
        os.makedirs(abs_datasets_path, exist_ok=True)

        self._logger.info(f"Converting datasets from {abs_tfds_path} to {abs_datasets_path}")

        utils_files = []
        with_manual_update = []
        imports_to_builder_map = {}

        for f_name in file_names:
            self._logger.info(f"Looking at file {f_name}")
            input_file = os.path.join(abs_tfds_path, f_name)
            output_file = os.path.join(abs_datasets_path, f_name)

            if not self._is_convertible(input_file, f_name):
                self._logger.info("Skipping file")
                continue

            with open(input_file, encoding="utf-8") as f:
                lines = f.readlines()

            out_lines, is_builder, needs_manual_update, tfds_imports = self._convert_lines(lines)

            if is_builder or "wmt" in f_name:
                # We create a new directory for each dataset
                output_dir = os.path.join(abs_datasets_path, os.path.splitext(f_name)[0])
                output_file = os.path.join(output_dir, f_name)
                os.makedirs(output_dir, exist_ok=True)
                self._logger.info(f"Adding directory {output_dir}")
                imports_to_builder_map.update(dict.fromkeys(tfds_imports, output_dir))
            else:
                # Utilities are copied next to the builders importing them once all files are done
                utils_files.append(output_file)

            if needs_manual_update:
                with_manual_update.append(output_file)

            with open(output_file, "w", encoding="utf-8") as f:
                f.writelines(out_lines)
            self._logger.info(f"Converted in {output_file}")

        for utils_file in utils_files:
            module_name = os.path.splitext(os.path.basename(utils_file))[0]
            dest_folder = imports_to_builder_map.get(module_name)
            if dest_folder is None:
                self._logger.error(f"Cannot find destination folder for {utils_file}. Please copy manually.")
                continue
            self._logger.info(f"Copying {utils_file} to {dest_folder}")
            shutil.copy(utils_file, dest_folder)

        for file_path in with_manual_update:
            self._logger.warning(
                f"You need to manually update file {file_path} to remove configurations using 'TextEncoderConfig'."
            )
|
|