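"""Streamlit demo that extracts learning conditions (TOPIC, STYLE, LENGTH,
LANGUAGE) from a learning-goal sentence with a BIO token-classification model."""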
import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
# Mapping from model output indices to BIO tags for the four condition types.
id_to_label = {
    0: 'O',
    1: 'B-TOPIC',
    2: 'I-TOPIC',
    3: 'B-STYLE',
    4: 'I-STYLE',
    5: 'B-LENGTH',
    6: 'I-LENGTH',
    7: 'B-LANGUAGE',
    8: 'I-LANGUAGE'
}
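# Under this scheme, a span such as 딥러닝 (a TOPIC) is tagged character by
# character as B-TOPIC, I-TOPIC, I-TOPIC; characters outside any span get O.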
@st.cache_resource
def load_model():
    # Load the fine-tuned tokenizer and token-classification model from the
    # repository root; st.cache_resource keeps them loaded across reruns.
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = AutoModelForTokenClassification.from_pretrained(".")
    return tokenizer, model

tokenizer, model = load_model()
def predict(text, model, tokenizer, id_to_label):
    # Split the input into individual characters and let the tokenizer treat
    # each character as a pre-split word.
    tokens = list(text)
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt",
                       truncation=True, max_length=128)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    # Map each subword prediction back to its source character; special tokens
    # ([CLS], [SEP]) have a word_id of None and are skipped.
    word_ids = inputs.word_ids(batch_index=0)
    pred_labels = []
    tokens_out = []
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue
        token = tokens[word_idx]
        label = id_to_label[predictions[0][idx].item()]
        tokens_out.append(token)
        pred_labels.append(label)
    return tokens_out, pred_labels
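# Illustrative call (the actual tags depend on the fine-tuned weights):
#   predict("딥러닝을 배우고 싶어요", model, tokenizer, id_to_label)
#   -> (['딥', '러', '닝', '을', ...], ['B-TOPIC', 'I-TOPIC', 'I-TOPIC', 'O', ...])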
def post_process(tokens, labels):
    # Merge WordPiece continuation pieces ("##...") back into whole words,
    # keeping the label of each word's first piece, and drop special tokens.
    words, word_labels = [], []
    current_word = ""
    current_label = None
    for token, label in zip(tokens, labels):
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        if token.startswith("##"):
            current_word += token[2:]
        else:
            if current_word:
                words.append(current_word)
                word_labels.append(current_label)
            current_word = token
            current_label = label
    # Flush the final word.
    if current_word:
        words.append(current_word)
        word_labels.append(current_label)
    return words, word_labels
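# Illustrative merge: tokens ['학', '##습'] labeled ['B-TOPIC', 'I-TOPIC'] would
# collapse to the single word '학습' with label 'B-TOPIC'. With the character
# tokens that predict() emits, no '##' pieces actually occur, so each character
# passes through as its own word; the merge is kept as a safeguard.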
def align_words_labels(words, labels):
    return list(zip(words, labels))
def extract_entities(aligned_result):
    # Decode BIO tags into entity spans: B- opens a new span, I- of the same
    # type extends it, and O (or a mismatched tag) closes the current span.
    entities, current_entity, current_text = [], None, ""
    for word, label in aligned_result:
        if label == "O":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = None, ""
            continue
        prefix, entity_type = label.split("-", 1)
        if prefix == "B":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
        elif prefix == "I" and current_entity == entity_type:
            current_text += word
        else:
            # Stray I- tag without a matching open span: start a new span.
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
    if current_entity:
        entities.append({"entity": current_entity, "text": current_text})
    return entities
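# Illustrative decode:
#   [('딥', 'B-TOPIC'), ('러', 'I-TOPIC'), ('닝', 'I-TOPIC'), ('을', 'O')]
#   -> [{'entity': 'TOPIC', 'text': '딥러닝'}]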
# Streamlit UI
st.title("🎯 Learning Condition Extractor")
st.write("Extracts the conditions (TOPIC, STYLE, LENGTH, LANGUAGE) from the user's learning-goal sentence.")

# Default example in Korean, the language the model expects:
# "I want to learn deep learning hands-on within 30 minutes."
user_input = st.text_input("Enter your learning goal:", value="딥러닝을 실습 위주로 30분 이내에 배우고 싶어요")

if st.button("Run inference"):
    tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
    words, word_labels = post_process(tokens, pred_labels)
    aligned = align_words_labels(words, word_labels)
    entities = extract_entities(aligned)
    # One slot per condition type; the last extracted span of each type wins.
    result_dict = {'TOPIC': None, 'STYLE': None, 'LENGTH': None, 'LANGUAGE': None}
    for ent in entities:
        result_dict[ent['entity']] = ent['text']
    st.subheader("Extracted Conditions")
    st.json(result_dict)
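# Illustrative output for the default example (the actual spans depend
# entirely on the trained model):
#   {"TOPIC": "딥러닝", "STYLE": "실습", "LENGTH": "30분", "LANGUAGE": null}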