|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Credits |
|
This code is modified from https://github.com/GitYCC/g2pW |
|
""" |
|
|
|
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from .utils import tokenize_and_map
|
|
|
ANCHOR_CHAR = "▁" |
|
|
|
|
|
def prepare_onnx_input(
    tokenizer,
    labels: List[str],
    char2phonemes: Dict[str, List[int]],
    chars: List[str],
    texts: List[str],
    query_ids: List[int],
    use_mask: bool = False,
    window_size: Optional[int] = None,
    max_len: int = 512,
) -> Dict[str, np.ndarray]:
    """Build the batched input tensors for the g2pW ONNX model.

    Args:
        tokenizer: BERT-style tokenizer exposing ``convert_tokens_to_ids``.
        labels: all candidate phoneme labels.
        char2phonemes: maps a polyphonic char to its candidate label indices.
        chars: all polyphonic chars; ``char_ids`` are indices into this list.
        texts: input sentences.
        query_ids: per sentence, the char index to disambiguate.
        use_mask: if True, ``phoneme_masks`` only enables the query char's
            candidate labels; otherwise every label is enabled.
        window_size: if given, each text is first clipped to a window of this
            many chars centered on the query char.
        max_len: maximum token length including [CLS] and [SEP].

    Returns:
        Dict of int64/float32 numpy arrays keyed by the model's input names,
        or an empty dict if any text fails tokenization.
    """
    # Fix: the original mixed `window_size is not None` (here) with a
    # truthiness test inside the loop, so window_size=0 truncated the texts
    # but then read the untruncated ones. Use one consistent flag.
    use_window = window_size is not None
    if use_window:
        truncated_texts, truncated_query_ids = _truncate_texts(
            window_size=window_size, texts=texts, query_ids=query_ids
        )
    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []

    for idx in range(len(texts)):
        text = (truncated_texts if use_window else texts)[idx].lower()
        query_id = (truncated_query_ids if use_window else query_ids)[idx]

        try:
            tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
        except Exception:
            # Upstream g2pW behavior: abort the whole batch on a bad text.
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
        )

        processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]

        # Plain Python lists suffice here; the final np.array(...).astype(...)
        # below fixes the dtype (the original round-tripped through np arrays).
        input_id = list(tokenizer.convert_tokens_to_ids(processed_tokens))
        token_type_id = [0] * len(processed_tokens)
        attention_mask = [1] * len(processed_tokens)

        query_char = text[query_id]
        if use_mask:
            # Set lookup: avoids an O(len(labels) * len(candidates)) list scan.
            candidates = set(char2phonemes[query_char])
            phoneme_mask = [1 if i in candidates else 0 for i in range(len(labels))]
        else:
            phoneme_mask = [1] * len(labels)
        char_id = chars.index(query_char)
        # +1 skips the [CLS] token prepended above.
        position_id = text2token[query_id] + 1

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)

    outputs = {
        "input_ids": np.array(input_ids).astype(np.int64),
        "token_type_ids": np.array(token_type_ids).astype(np.int64),
        "attention_masks": np.array(attention_masks).astype(np.int64),
        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
        "char_ids": np.array(char_ids).astype(np.int64),
        "position_ids": np.array(position_ids).astype(np.int64),
    }
    return outputs
|
|
|
|
|
def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]: |
|
truncated_texts = [] |
|
truncated_query_ids = [] |
|
for text, query_id in zip(texts, query_ids): |
|
start = max(0, query_id - window_size // 2) |
|
end = min(len(text), query_id + window_size // 2) |
|
truncated_text = text[start:end] |
|
truncated_texts.append(truncated_text) |
|
|
|
truncated_query_id = query_id - start |
|
truncated_query_ids.append(truncated_query_id) |
|
return truncated_texts, truncated_query_ids |
|
|
|
|
|
def _truncate( |
|
max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]] |
|
): |
|
truncate_len = max_len - 2 |
|
if len(tokens) <= truncate_len: |
|
return (text, query_id, tokens, text2token, token2text) |
|
|
|
token_position = text2token[query_id] |
|
|
|
token_start = token_position - truncate_len // 2 |
|
token_end = token_start + truncate_len |
|
font_exceed_dist = -token_start |
|
back_exceed_dist = token_end - len(tokens) |
|
if font_exceed_dist > 0: |
|
token_start += font_exceed_dist |
|
token_end += font_exceed_dist |
|
elif back_exceed_dist > 0: |
|
token_start -= back_exceed_dist |
|
token_end -= back_exceed_dist |
|
|
|
start = token2text[token_start][0] |
|
end = token2text[token_end - 1][1] |
|
|
|
return ( |
|
text[start:end], |
|
query_id - start, |
|
tokens[token_start:token_end], |
|
[i - token_start if i is not None else None for i in text2token[start:end]], |
|
[(s - start, e - start) for s, e in token2text[token_start:token_end]], |
|
) |
|
|
|
|
|
def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build the sorted phoneme label list and the per-char candidate-index map.

    Args:
        polyphonic_chars: (char, phoneme) pairs; a char may appear multiple
            times, once per candidate phoneme.

    Returns:
        ``labels``: sorted, de-duplicated phonemes; ``char2phonemes``: maps
        each char to the indices (into ``labels``) of its candidates, in the
        order the pairs appear in the input.
    """
    labels = sorted({phoneme for _, phoneme in polyphonic_chars})
    # Precomputed index map: the original called labels.index() per pair,
    # an O(n) scan each time.
    label_index = {label: i for i, label in enumerate(labels)}
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label_index[phoneme])
    return labels, char2phonemes
|
|
|
|
|
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build sorted "char phoneme" labels and the per-char candidate-index map.

    Same as :func:`get_phoneme_labels` except each label is the joint string
    ``f"{char} {phoneme}"``, so identical phonemes of different chars get
    distinct label ids.

    Args:
        polyphonic_chars: (char, phoneme) pairs; a char may appear multiple
            times, once per candidate phoneme.

    Returns:
        ``labels``: sorted, de-duplicated joint labels; ``char2phonemes``:
        maps each char to the indices (into ``labels``) of its candidates,
        in the order the pairs appear in the input.
    """
    labels = sorted({f"{char} {phoneme}" for char, phoneme in polyphonic_chars})
    # Precomputed index map: the original called labels.index() per pair,
    # an O(n) scan each time.
    label_index = {label: i for i, label in enumerate(labels)}
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label_index[f"{char} {phoneme}"])
    return labels, char2phonemes
|
|