# -*- encoding: utf-8 -*-
# Time: 2024/12/02 19:58:58
# Desc:
import os
import re
import yaml
from typing import Optional, List
from enum import Enum
from .text_norm.normalizer import Normalizer
CN_REGEX = "\u4e00-\u4E27\u4E29-\u4E3E\u4E42-\u9fa4"  # skip the CJK stroke characters \u4E28\u4E3F\u4E40\u4E41
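# Illustrative check (not in the original source): characters inside the ranges above
# match, while the excluded stroke characters do not.
#   >>> re.findall(f"[{CN_REGEX}]", "中文a丨")
#   ['中', '文']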
class SentencePieceType(str, Enum):
    END_OF_SENTENCE = "END_OF_SENTENCE"
def split_with_separator(regx_sep, text):
    """Difference from `re.split`: the separator is kept at the end of each piece."""
    split_list = []
    start = 0
    for match in re.finditer(regx_sep, text):
        end = match.span()[1]
        assert end > start
        split_list.append(text[start:end])
        start = end
    split_list.append(text[start:])  # may be an empty string
    return split_list
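# Illustrative behavior (doctest-style sketch, not in the original source): each
# separator stays attached to the end of its piece, and the trailing remainder
# (possibly empty) is always the last element.
#   >>> split_with_separator("[,。]", "你好,世界。尾巴")
#   ['你好,', '世界。', '尾巴']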
def split(text, split_pattern, split_cn_length=None):
    """
    Args:
        text: input text to segment
        split_pattern: regex of sentence-splitting tokens
        split_cn_length: if set, a trailing clause with at least this many Chinese
            characters is also treated as a complete piece
    Return:
        (split_list: List[str], remain: str)
    """
    split_list = split_with_separator(split_pattern, text)
    remain = split_list.pop(-1)  # the last piece is an incomplete clause and may be an empty string
    # If the trailing incomplete clause has enough Chinese characters, promote it into split_list as well
    if split_cn_length is not None:
        text_split = re.search(f"^[。!?,{CN_REGEX}]" + "{" + f"{split_cn_length}," + "}", remain)
        if text_split:
            text_split = text_split.group()
            split_list.append(text_split)
            remain = remain[len(text_split):]
    return split_list, remain
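# Illustrative behavior (not in the original source): complete clauses are returned
# in `split_list`; the trailing incomplete clause is returned as `remain`, unless it
# already holds at least `split_cn_length` Chinese characters/punctuation marks, in
# which case it is promoted into `split_list`.
#   >>> split("一句话。还没说完", "[。!?]")
#   (['一句话。'], '还没说完')
#   >>> split("一句话。还没说完", "[。!?]", split_cn_length=2)
#   (['一句话。', '还没说完'], '')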
class SentenceNormalizer(Normalizer):
    def __init__(self, config={}):
        self.config = config

    def normalize(self, text, context: str = ""):
        text = self.preprocess(text)
        text, norm_details = self.normalize_regular(text, is_en=False)
        text = self.postprocess(text, custom=self.config["postprocess"])
        text = text[len(context):]
        return text
class SentenceManager:
    def __init__(self, tokenizer, normalizer, config):
        """
        Args:
            tokenizer: required, because for English text and other special symbols,
                tokenization and string concatenation are not commutative
        """
        self.split_pattern = "|".join(config["split_token"])
        self.split_cn_length = config["split_cn_length"]
        self.tokenizer = tokenizer
        self.normalizer = normalizer
        self.context: Optional[str] = ""
        self.cache: List[int] = []
        self.output_queue: List[List[int]] = []
    def put(self, token_id):
        # Decode the cached token ids together with the incoming one.
        text = self.tokenizer.decode([*self.cache, token_id])
        # Split into complete clauses plus a trailing incomplete remainder.
        split_list, remain = split(text, self.split_pattern, split_cn_length=self.split_cn_length)
        assert split_list or remain
        if split_list:
            normalized_split_list = [
                self.normalizer.normalize(x) if i < len(split_list) else self.normalizer.normalize(x, self.context)
                for i, x in enumerate(split_list)
            ]
            # Remember the last clause as context when normalization left its length unchanged.
            if len(normalized_split_list[-1]) == len(split_list[-1]):
                self.context = split_list[-1]
            # Re-encode each normalized clause and queue it for `get`.
            token_ids_list = [self.tokenizer.encode(x) for x in normalized_split_list if x]
            self.output_queue.extend(token_ids_list)
            # A clause ending with a split token marks a sentence boundary.
            if re.search(f"({self.split_pattern})$", split_list[-1]):
                self.output_queue.append(SentencePieceType.END_OF_SENTENCE)
        # Keep the incomplete remainder (if any) tokenized in the cache.
        if remain:
            token_ids_remain = self.tokenizer.encode(remain)
            self.cache = token_ids_remain
        else:
            self.cache = []
    def get(self):
        if self.output_queue:
            return self.output_queue.pop(0)
        else:
            return None
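# Minimal smoke-test sketch (not part of the original module). A character-level toy
# tokenizer and an identity normalizer stand in for the real dependencies, which are
# provided elsewhere in the project; the config keys mirror the ones read in
# SentenceManager.__init__.
if __name__ == "__main__":
    class _ToyTokenizer:
        def encode(self, text):
            return [ord(ch) for ch in text]

        def decode(self, token_ids):
            return "".join(chr(t) for t in token_ids)

    class _IdentityNormalizer:
        def normalize(self, text, context: str = ""):
            return text[len(context):]

    tokenizer = _ToyTokenizer()
    manager = SentenceManager(
        tokenizer=tokenizer,
        normalizer=_IdentityNormalizer(),
        config={"split_token": ["。", "!", "?"], "split_cn_length": 10},
    )

    # Feed token ids one by one, draining complete clauses as they become available.
    for token_id in tokenizer.encode("今天天气不错。我们出去走走吧"):
        manager.put(token_id)
        piece = manager.get()
        while piece is not None:
            if piece is SentencePieceType.END_OF_SENTENCE:
                print("<sentence boundary>")
            else:
                print(tokenizer.decode(piece))
            piece = manager.get()
    # Text after the last split token ("我们出去走走吧") stays in manager.cache until
    # more tokens arrive.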