# -*- encoding: utf-8 -*-
# Time: 2024/12/02 19:58:58
# Desc: streaming sentence splitting and incremental text normalization for token-by-token decoding

import re
from typing import List, Union
from enum import Enum

from .text_norm.normalizer import Normalizer


CN_REGEX = "\u4e00-\u4E27\u4E29-\u4E3E\u4E42-\u9fa4"  # CJK range, skipping the stroke characters \u4E28, \u4E3F, \u4E40 and \u4E41


class SentencePieceType(str, Enum):
    """Sentinel values that SentenceManager interleaves with token-id lists in its output queue."""
    END_OF_SENTENCE = "END_OF_SENTENCE"

def split_with_separator(regx_sep, text):
    """difference with `re.split`: include separator
    """
    split_list = []
    start = 0
    for match in re.finditer(regx_sep, text):
        end = match.span()[1]
        assert end > start
        split_list.append(text[start: end])
        start = end
    split_list.append(text[start:])  # could be empty string
    return split_list
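
# Illustrative sketch (not in the original code): `split_with_separator` keeps each
# separator attached to the clause it terminates and always appends the trailing
# remainder. For example:
#   split_with_separator("。|!|?", "你好。吃了吗?还没")  -> ["你好。", "吃了吗?", "还没"]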


def split(text, split_pattern, split_cn_length=None):
    """
    Args:
        text
    Return:
        (split_list: List[str], remain: str)
    """
    split_list = split_with_separator(split_pattern, text)
    remain = split_list.pop(-1)  # 最后一项为不完整子句,可能为空字符串

    # 针对末尾的不完整子句,若满足中文字数条件,也添加进split_list
    if split_cn_length is not None:
        text_split = re.search(f"^[。!?,{CN_REGEX}]" + "{" + f"{split_cn_length}," + "}", remain)
        if text_split:
            text_split = text_split.group()
            split_list.append(text_split)
            remain = remain[len(text_split):]
    return split_list, remain
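
# Illustrative sketch (not in the original code): `split` separates complete clauses
# from the incomplete remainder; `split_cn_length` lets a long enough Chinese remainder
# count as complete too. For example:
#   split("你好。今天天气", "。|!|?")                     -> (["你好。"], "今天天气")
#   split("你好。今天天气", "。|!|?", split_cn_length=2)   -> (["你好。", "今天天气"], "")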


class SentenceNormalizer(Normalizer):
    def __init__(self, config=None):
        self.config = config if config is not None else {}

    def normalize(self, text, context: str = ""):
        # Normalize `context + text` so the normalizer sees the preceding clause, then
        # strip the context prefix again by character count. The caller only passes a
        # non-empty context when its normalization is known to preserve its length.
        text = self.preprocess(context + text)
        text, norm_details = self.normalize_regular(text, is_en=False)
        text = self.postprocess(text, custom=self.config["postprocess"])
        return text[len(context):]


class SentenceManager:
    def __init__(self, tokenizer, normalizer, config):
        """
        Args:
            tokenizer: tokenizer为必填参数,因为对英文等特殊符号,tokenize和拼接操作不是可交换的
        """

        self.split_pattern = "|".join(config["split_token"])
        self.split_cn_length = config["split_cn_length"]

        self.tokenizer = tokenizer
        self.normalizer = normalizer

        self.context: Optional[str] = ""
        self.cache: List[int] = []
        self.output_queue: List[List[int]] = []

    def put(self, token_id):
        """Feed one decoded token id; completed clauses are normalized, re-encoded and queued."""
        text = self.tokenizer.decode([*self.cache, token_id])
        split_list, remain = split(text, self.split_pattern, split_cn_length=self.split_cn_length)
        assert split_list or remain

        if split_list:
            # Only the first clause directly follows the previously emitted text, so only
            # it is normalized with `self.context`.
            normalized_split_list = [
                self.normalizer.normalize(x, self.context) if i == 0 else self.normalizer.normalize(x)
                for i, x in enumerate(split_list)
            ]
            # The context is stripped by character count in SentenceNormalizer.normalize,
            # so the last clause can only serve as context if normalization kept its length.
            if len(normalized_split_list[-1]) == len(split_list[-1]):
                self.context = split_list[-1]
            else:
                self.context = ""
            token_ids_list = [self.tokenizer.encode(x) for x in normalized_split_list if x]
            self.output_queue.extend(token_ids_list)
            if re.search(f"({self.split_pattern})$", split_list[-1]):
                self.output_queue.append(SentencePieceType.END_OF_SENTENCE)

        if remain:
            self.cache = self.tokenizer.encode(remain)  # re-encode from text, see the __init__ note
        else:
            self.cache = []

    def get(self):
        """Return the next queued item (token-id list or END_OF_SENTENCE), or None if the queue is empty."""
        if self.output_queue:
            return self.output_queue.pop(0)
        else:
            return None
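

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original pipeline).
    # `DummyTokenizer` and `DummyNormalizer` are hypothetical stand-ins; a real setup
    # would pass an actual subword tokenizer and a configured SentenceNormalizer.
    class DummyTokenizer:
        # Character-level "tokenizer": ids are Unicode code points.
        def encode(self, text):
            return [ord(ch) for ch in text]

        def decode(self, token_ids):
            return "".join(chr(i) for i in token_ids)

    class DummyNormalizer:
        # Identity normalization; a real normalizer would expand numbers, dates, etc.
        def normalize(self, text, context: str = ""):
            return text

    tokenizer = DummyTokenizer()
    config = {"split_token": ["。", "!", "?"], "split_cn_length": None}
    manager = SentenceManager(tokenizer, DummyNormalizer(), config)

    # Stream token ids one by one and drain the output queue after each step.
    for token_id in tokenizer.encode("你好。今天天气很好!"):
        manager.put(token_id)
        while (item := manager.get()) is not None:
            if item is SentencePieceType.END_OF_SENTENCE:
                print("<eos>")
            else:
                print(tokenizer.decode(item))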