import os os.environ["TRANSFORMERS_CACHE"] = "/tmp" from fastapi import FastAPI from pydantic import BaseModel from transformers import PegasusTokenizer, PegasusForConditionalGeneration import torch app = FastAPI() # Load model và tokenizer model_name = "google/pegasus-cnn_dailymail" tokenizer = PegasusTokenizer.from_pretrained(model_name) model = PegasusForConditionalGeneration.from_pretrained(model_name) # Dùng GPU nếu có device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # Định nghĩa input schema class InputText(BaseModel): text: str # Hàm tóm tắt tự động điều chỉnh độ dài theo số token def summarize(text: str) -> str: # Tokenize input text inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024) input_length = inputs["input_ids"].shape[1] # Xác định độ dài summary theo tỷ lệ input summary_max_len = max(30, int(input_length * 0.2)) # tối đa khoảng 20% số token summary_min_len = max(15, int(summary_max_len * 0.6)) # tối thiểu khoảng 60% max inputs = {k: v.to(device) for k, v in inputs.items()} # Sinh summary summary_ids = model.generate( inputs["input_ids"], max_length=summary_max_len, min_length=summary_min_len, num_beams=4, no_repeat_ngram_size=3, early_stopping=True ) return tokenizer.decode(summary_ids[0], skip_special_tokens=True) # API route @app.post("/summarize") def summarize_api(input: InputText): return {"summary": summarize(input.text)}