|
import torch |
|
import torch.nn as nn |
|
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM |
|
import gradio as gr |
|
|
|
|
|
class MultiTaskModel(nn.Module):
    """Shared transformer encoder feeding two linear heads: topic and sentiment.

    Both heads read the encoder's final hidden state at sequence position 0.
    """

    def __init__(self, base_model_name, num_topic_classes, num_sentiment_classes):
        super().__init__()
        # Pretrained encoder; its hidden size fixes the input width of both heads.
        self.encoder = AutoModel.from_pretrained(base_model_name)
        width = self.encoder.config.hidden_size
        # NOTE: these attribute names become state_dict key prefixes
        # (e.g. "topik_classifier.weight"); renaming them breaks loading
        # previously saved checkpoints.
        self.topik_classifier = nn.Linear(width, num_topic_classes)
        self.sentiment_classifier = nn.Linear(width, num_sentiment_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode a tokenized batch and return ``(topic_logits, sentiment_logits)``."""
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden = encoded.last_hidden_state
        # Representation of the first token of every sequence (presumably the
        # [CLS] slot for a BERT-style tokenizer — confirm for other encoders).
        first_token = hidden[:, 0]
        topic_logits = self.topik_classifier(first_token)
        sentiment_logits = self.sentiment_classifier(first_token)
        return topic_logits, sentiment_logits
|
|
|
|
|
# Classifier tokenizer, read from the local "tokenizer" directory.
tokenizer = AutoTokenizer.from_pretrained("tokenizer")

# Build the two-head model on the IndoBERT encoder, then overwrite its
# parameters with the fine-tuned checkpoint in model.pt (mapped to CPU).
# The class counts (5 topics, 3 sentiments) must agree with the lengths of
# the label lists that analyze() indexes into.
model = MultiTaskModel(
    "indobenchmark/indobert-base-p1",
    num_topic_classes=5,
    num_sentiment_classes=3,
)
model.load_state_dict(
    torch.load("model.pt", map_location="cpu")
)
# Inference only: put layers such as dropout into evaluation mode.
model.eval()
|
|
|
|
|
# Summarization model (presumably an Indonesian BART, judging by its hub ID);
# tokenizer and seq2seq weights are loaded from the same repository.
_SUMMARIZER_ID = "xTorch8/bart-id-summarization"
sum_tok = AutoTokenizer.from_pretrained(_SUMMARIZER_ID)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(_SUMMARIZER_ID)
|
|
|
|
|
# Human-readable class names. Position i in each list names output index i
# of the matching classifier head, so the order must not change.
labels_topik = [
    "Produk",
    "Layanan",
    "Pengiriman",
    "Pembatalan",
    "Lainnya",
]
labels_sentiment = [
    "Negatif",
    "Netral",
    "Positif",
]
|
|
|
|
|
def _classify(text):
    """Predict ``(topic_label, sentiment_label)`` for one text.

    Uses the module-level ``tokenizer`` and multi-task ``model``; labels come
    from ``labels_topik`` and ``labels_sentiment``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        topic_logits, sentiment_logits = model(**encoded)
    # A single text gives a batch of one: take the best class along the class
    # axis for that row (explicit dim instead of an argmax over the flattened
    # tensor, which is only correct by accident for batch size 1).
    topic_idx = int(topic_logits.argmax(dim=-1)[0])
    sentiment_idx = int(sentiment_logits.argmax(dim=-1)[0])
    return labels_topik[topic_idx], labels_sentiment[sentiment_idx]


def _summarize(text):
    """Generate a short summary of *text* with the module-level summarizer."""
    encoded = sum_tok(text, return_tensors="pt", truncation=True, padding=True)
    summary_ids = sum_model.generate(**encoded, max_length=50, num_beams=2)
    return sum_tok.decode(summary_ids[0], skip_special_tokens=True)


def analyze(text):
    """Analyze one customer text: topic, sentiment, and a short summary.

    Args:
        text: Customer text from the Gradio textbox (may be empty).

    Returns:
        A multi-line report string (in Indonesian). Empty, whitespace-only,
        or ``None`` input returns a short notice without running the models.
    """
    # Without this guard, empty input would still be sent through both
    # models, and the summarizer would likely return a meaningless summary.
    if text is None or not text.strip():
        return "HASIL ANALISIS\nTeks kosong. Silakan masukkan teks pelanggan."

    topik, sentimen = _classify(text)
    ringkasan = _summarize(text)
    return (
        "HASIL ANALISIS\n"
        f"Topik: {topik}\n"
        f"Sentimen: {sentimen}\n"
        f"Ringkasan: {ringkasan}"
    )
|
|
|
|
|
# Gradio UI: a single free-text input and a text output, both handled by analyze().
demo = gr.Interface(
    fn=analyze,
    inputs="text",
    outputs="text",
    title="Analisis Topik, Sentimen, dan Ringkasan Pelanggan",
)

# Start the server only when this file is run as a script (`python app.py`),
# so that importing the module does not launch the web app as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|