import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
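
# The app assumes these packages are installed (versions not pinned here):
#   pip install streamlit transformers torch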

# Mapping from class ids to BIO tags: B-/I- prefixes mark the beginning and
# inside of a span for each condition type, O marks tokens outside any span.
id_to_label = {
    0: 'O',
    1: 'B-TOPIC',
    2: 'I-TOPIC',
    3: 'B-STYLE',
    4: 'I-STYLE',
    5: 'B-LENGTH',
    6: 'I-LENGTH',
    7: 'B-LANGUAGE',
    8: 'I-LANGUAGE',
}
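
# Illustrative hand labeling under this scheme (not model output): for the
# goal "딥러닝을 배우고 싶어요" ("I want to learn deep learning") split into
# characters,
#   딥 -> B-TOPIC, 러 -> I-TOPIC, 닝 -> I-TOPIC, 을 -> O, ...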

@st.cache_resource
def load_model():
    # Load the fine-tuned token-classification checkpoint from the current
    # working directory; cached so it is loaded once per Streamlit session.
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = AutoModelForTokenClassification.from_pretrained(".")
    return tokenizer, model
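
# If the checkpoint lived on the Hugging Face Hub instead of the working
# directory, the same calls would take a repo id (hypothetical name):
#   AutoTokenizer.from_pretrained("your-org/learning-condition-ner")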

tokenizer, model = load_model()

def predict(text, model, tokenizer, id_to_label):
    # Split the input into characters and let the tokenizer treat each
    # character as a pre-tokenized "word".
    tokens = list(text)
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt",
                       truncation=True, max_length=128)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    # Map subword positions back to the original characters (word_ids requires
    # a fast tokenizer). Keep only the first subword of each character so no
    # character is emitted twice.
    word_ids = inputs.word_ids(batch_index=0)
    pred_labels = []
    tokens_out = []
    previous_word_idx = None

    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:  # special tokens such as [CLS]/[SEP]
            continue
        if word_idx == previous_word_idx:  # skip subword continuations
            continue
        previous_word_idx = word_idx
        tokens_out.append(tokens[word_idx])
        pred_labels.append(id_to_label[predictions[0][idx].item()])

    return tokens_out, pred_labels
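
# Sketch of a direct call (the labels shown are illustrative, not a real
# model's output):
#   tokens, labels = predict("딥러닝을 배우고 싶어요", model, tokenizer, id_to_label)
#   tokens -> ['딥', '러', '닝', '을', ...]; labels -> ['B-TOPIC', 'I-TOPIC', ...]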

def post_process(tokens, labels):
    # Merge WordPiece continuations ("##...") back into whole words, keeping
    # the label of each word's first piece. With the character-level predict()
    # above this is mostly a pass-through, but it also handles raw tokenizer
    # output defensively.
    words, word_labels = [], []
    current_word = ""
    current_label = None
    for token, label in zip(tokens, labels):
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        if token.startswith("##"):
            current_word += token[2:]
        else:
            if current_word:
                words.append(current_word)
                word_labels.append(current_label)
            current_word = token
            current_label = label
    if current_word:  # flush the last word
        words.append(current_word)
        word_labels.append(current_label)
    return words, word_labels
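
# Illustrative merge (hand-made input):
#   post_process(["learn", "##ing"], ["B-TOPIC", "I-TOPIC"])
#   -> (["learning"], ["B-TOPIC"])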

def align_words_labels(words, labels):
    return list(zip(words, labels))

def extract_entities(aligned_result):
    # Decode BIO tags into entity spans: B- starts a span, a matching I-
    # extends it, and O (or a mismatched tag) closes the open span.
    entities, current_entity, current_text = [], None, ""
    for word, label in aligned_result:
        if label == "O":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = None, ""
            continue
        prefix, entity_type = label.split("-", 1)
        if prefix == "B":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
        elif prefix == "I" and current_entity == entity_type:
            current_text += word
        else:
            # An I- tag without a matching open span: treat it as a new span.
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
    if current_entity:  # flush a span that runs to the end of the sentence
        entities.append({"entity": current_entity, "text": current_text})
    return entities
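
# BIO decoding sanity check (hand-labeled input, not model output):
#   extract_entities([("딥", "B-TOPIC"), ("러", "I-TOPIC"), ("닝", "I-TOPIC"), ("을", "O")])
#   -> [{"entity": "TOPIC", "text": "딥러닝"}]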

st.title("🎯 Learning Condition Extractor")
st.write("Extracts the conditions (TOPIC, STYLE, LENGTH, LANGUAGE) from a user's learning-goal sentence.")

# Default example in Korean: "I want to learn deep learning hands-on within
# 30 minutes" -- the model expects Korean input.
user_input = st.text_input("Enter your learning goal:",
                           value="딥러닝을 실습 위주로 30분 이내에 배우고 싶어요")

if st.button("Start Inference"):
    tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
    words, word_labels = post_process(tokens, pred_labels)
    aligned = align_words_labels(words, word_labels)
    entities = extract_entities(aligned)

    # Keep the last extracted span for each condition type; conditions the
    # model did not find stay None.
    result_dict = {'TOPIC': None, 'STYLE': None, 'LENGTH': None, 'LANGUAGE': None}
    for ent in entities:
        result_dict[ent['entity']] = ent['text']

    st.subheader("Extracted Conditions")
    st.json(result_dict)
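
# To launch the app (assuming this file is app.py and the checkpoint files
# sit in the same directory):
#   streamlit run app.py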