|
from dataclasses import dataclass, field |
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union |
|
|
|
import pyarrow as pa |
|
|
|
|
|
if TYPE_CHECKING: |
|
from .features import FeatureType |
|
|
|
|
|
@dataclass
class Translation:
    """Feature describing translations whose set of languages is fixed for every example.

    Provided for compatibility with tfds.

    Args:
        languages (`list[str]`):
            Language codes that every example provides a translation for.
            Each example is then a dictionary mapping these string language
            codes to string translations.
        id (`str`, *optional*):
            Optional identifier stored on the feature; the code in this class
            does not read it.

    Example:

    ```python
    >>> # Declaring the feature:
    >>> datasets.features.Translation(languages=['en', 'fr', 'de'])
    >>> # Producing an example during data generation:
    >>> yield {
    ...         'en': 'the cat',
    ...         'fr': 'le chat',
    ...         'de': 'die katze'
    ... }
    ```
    """

    languages: list[str]
    id: Optional[str] = None

    # Class-level descriptors shared by all instances (not dataclass fields).
    dtype: ClassVar[str] = "dict"
    pa_type: ClassVar[Any] = None
    # Type tag; excluded from both __init__ and repr.
    _type: str = field(default="Translation", init=False, repr=False)

    def __call__(self):
        """Build the pyarrow struct type: one string field per language, in sorted order."""
        struct_fields = {}
        for code in sorted(self.languages):
            struct_fields[code] = pa.string()
        return pa.struct(struct_fields)

    def flatten(self) -> Union["FeatureType", dict[str, "FeatureType"]]:
        """Return a dictionary mapping each language code (sorted) to a string `Value`."""
        # Imported here rather than at module level (module-level import is TYPE_CHECKING-only).
        from .features import Value

        flattened = {}
        for code in sorted(self.languages):
            flattened[code] = Value("string")
        return flattened
|
|
|
|
|
@dataclass
class TranslationVariableLanguages:
    """`Feature` for translations with variable languages per example.
    Here for compatibility with tfds.

    Args:
        languages (`list`, *optional*):
            Language codes allowed in examples. On construction, duplicates are
            removed and the codes are sorted. When omitted (or empty),
            `encode_example` accepts any language code.
        num_languages (`int`, *optional*):
            Overwritten on construction with the number of distinct
            `languages`, or `None` when no languages are given.
        id (`str`, *optional*):
            Optional identifier stored on the feature; the code in this class
            does not read it.

    Returns:
        - `language` or `translation` (variable-length sequences of strings):
            Language codes sorted in ascending order or plain text translations,
            sorted to align with language codes.

    Example:

    ```python
    >>> # At construction time:
    >>> datasets.features.TranslationVariableLanguages(languages=['en', 'fr', 'de'])
    >>> # During data generation:
    >>> yield {
    ...         'en': 'the cat',
    ...         'fr': ['le chat', 'la chatte'],
    ...         'de': 'die katze'
    ... }
    >>> # Encoded example:
    >>> {
    ...         'language': ('de', 'en', 'fr', 'fr'),
    ...         'translation': ('die katze', 'the cat', 'la chatte', 'le chat'),
    ... }
    ```
    """

    languages: Optional[list] = None
    num_languages: Optional[int] = None
    id: Optional[str] = None

    # Class-level descriptors shared by all instances (not dataclass fields).
    dtype: ClassVar[str] = "dict"
    pa_type: ClassVar[Any] = None
    # Type tag; excluded from both __init__ and repr.
    _type: str = field(default="TranslationVariableLanguages", init=False, repr=False)

    def __post_init__(self):
        """Normalize `languages` (dedupe + sort) and derive `num_languages` from it."""
        self.languages = sorted(set(self.languages)) if self.languages else None
        self.num_languages = len(self.languages) if self.languages else None

    def __call__(self):
        """Return the pyarrow struct type: two parallel lists of strings."""
        return pa.struct({"language": pa.list_(pa.string()), "translation": pa.list_(pa.string())})

    def encode_example(self, translation_dict):
        """Encode a ``{language: text or iterable of texts}`` mapping into parallel sequences.

        Args:
            translation_dict (`dict`):
                Either a mapping from language code to a single string or an
                iterable of strings, or an already-encoded dictionary whose
                keys are exactly ``"language"`` and ``"translation"``.

        Returns:
            `dict`: ``{"language": ..., "translation": ...}`` holding two
            aligned tuples sorted by (language, translation). An already-encoded
            input is returned unchanged. An input that yields no translations
            produces two empty tuples.

        Raises:
            ValueError: If `languages` was configured and the example contains
                a language code outside of it.
        """
        # Already-encoded examples are passed through untouched.
        if set(translation_dict) == {"language", "translation"}:
            return translation_dict

        # Only validate when a language whitelist was configured; `languages`
        # is None by default, which means "accept anything".
        if self.languages:
            allowed = set(self.languages)
            unknown = set(translation_dict) - allowed
            if unknown:
                # Sort both sides so the message is deterministic.
                raise ValueError(
                    f"Some languages in example ({', '.join(sorted(unknown))}) are not in valid set ({', '.join(sorted(allowed))})."
                )

        # Expand into one (language, text) pair per translation; a language may
        # carry either a single string or several.
        translation_tuples = []
        for lang, text in translation_dict.items():
            if isinstance(text, str):
                translation_tuples.append((lang, text))
            else:
                translation_tuples.extend((lang, el) for el in text)

        # zip(*[]) yields nothing to unpack, so handle the empty case explicitly.
        if not translation_tuples:
            return {"language": (), "translation": ()}

        # Sort by language (then text) and split back into two aligned tuples.
        languages, translations = zip(*sorted(translation_tuples))

        return {"language": languages, "translation": translations}

    def flatten(self) -> Union["FeatureType", dict[str, "FeatureType"]]:
        """Flatten the TranslationVariableLanguages feature into a dictionary."""
        # Imported here rather than at module level (module-level import is TYPE_CHECKING-only).
        from .features import Sequence, Value

        return {
            "language": Sequence(Value("string")),
            "translation": Sequence(Value("string")),
        }
|
|