Spaces:
Sleeping
Sleeping
import os | |
os.environ["TRANSFORMERS_CACHE"] = "/tmp" | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from transformers import PegasusTokenizer, PegasusForConditionalGeneration | |
import torch | |
app = FastAPI() | |
# Load model và tokenizer | |
model_name = "google/pegasus-cnn_dailymail" | |
tokenizer = PegasusTokenizer.from_pretrained(model_name) | |
model = PegasusForConditionalGeneration.from_pretrained(model_name) | |
# Dùng GPU nếu có | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model = model.to(device) | |
# Định nghĩa input schema | |
class InputText(BaseModel): | |
text: str | |
# Hàm tóm tắt tự động điều chỉnh độ dài theo số token | |
def summarize(text: str) -> str: | |
# Tokenize input text | |
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024) | |
input_length = inputs["input_ids"].shape[1] | |
# Xác định độ dài summary theo tỷ lệ input | |
summary_max_len = max(30, int(input_length * 0.2)) # tối đa khoảng 20% số token | |
summary_min_len = max(15, int(summary_max_len * 0.6)) # tối thiểu khoảng 60% max | |
inputs = {k: v.to(device) for k, v in inputs.items()} | |
# Sinh summary | |
summary_ids = model.generate( | |
inputs["input_ids"], | |
max_length=summary_max_len, | |
min_length=summary_min_len, | |
num_beams=4, | |
no_repeat_ngram_size=3, | |
early_stopping=True | |
) | |
return tokenizer.decode(summary_ids[0], skip_special_tokens=True) | |
# API route | |
def summarize_api(input: InputText): | |
return {"summary": summarize(input.text)} | |