import gradio as gr import torch import gc from transformers import AutoTokenizer, AutoModelForCausalLM import os # 清理内存 torch.cuda.empty_cache() gc.collect() # 设置环境变量 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" # 模型名称 model_name = "您的用户名/text-style-converter" # 全局变量存储模型 tokenizer = None model = None def load_model(): """延迟加载模型""" global tokenizer, model if tokenizer is None or model is None: try: print("正在加载tokenizer...") tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, use_fast=False # 使用慢速tokenizer减少内存 ) print("正在加载模型...") model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, # 使用半精度 device_map="cpu", # 强制使用CPU low_cpu_mem_usage=True, # 启用低内存模式 trust_remote_code=True, load_in_8bit=False, # 在CPU上不使用量化 offload_folder="./offload", # 设置offload文件夹 ) # 设置pad_token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("模型加载完成!") except Exception as e: print(f"模型加载失败: {str(e)}") return False return True def convert_text_style(input_text): """文本风格转换函数""" if not input_text.strip(): return "请输入要转换的文本" # 检查模型是否加载 if not load_model(): return "模型加载失败,请稍后重试" try: prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。 ### 输入文本: {input_text} ### 输出文本: """ # 编码输入 inputs = tokenizer( prompt, return_tensors="pt", max_length=1024, # 限制输入长度 truncation=True, padding=True ) # 生成回答 with torch.no_grad(): # 不计算梯度节省内存 outputs = model.generate( inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=300, # 减少生成长度 temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, num_return_sequences=1, no_repeat_ngram_size=2 ) # 解码输出 full_response = tokenizer.decode(outputs[0], skip_special_tokens=True) # 提取生成的部分 if "### 输出文本:" in full_response: response = full_response.split("### 输出文本:")[-1].strip() else: response = full_response[len(prompt):].strip() # 清理内存 del inputs, outputs torch.cuda.empty_cache() gc.collect() return response if response else "抱歉,未能生成有效回答" except Exception as e: return f"生成过程中出现错误: {str(e)}" # 创建Gradio接口 def create_interface(): iface = gr.Interface( fn=convert_text_style, inputs=gr.Textbox( label="输入文本", placeholder="请输入需要转换为口语化的书面文本...", lines=3 ), outputs=gr.Textbox( label="输出文本", lines=3 ), title="中文文本风格转换API", description="将书面化、技术性文本转换为自然、口语化表达", examples=[ ["乙醇的检测方法包括酸碱度检查。"], ["本品为薄膜衣片,除去包衣后显橙红色。"] ], cache_examples=False, # 不缓存示例 allow_flagging="never" # 禁用标记功能 ) return iface # 启动应用 if __name__ == "__main__": print("正在启动应用...") iface = create_interface() iface.launch( server_name="0.0.0.0", server_port=7860, share=False, debug=False, enable_queue=True, max_threads=1 # 限制线程数 )