# Edutube / streamlit_app.py
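"""Streamlit demo that extracts learning conditions (TOPIC, STYLE, LENGTH,
LANGUAGE) from a user's learning-goal sentence with a token-classification model.

Run locally with: streamlit run streamlit_app.py
(requires streamlit, torch, and transformers, per the imports below).
"""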
import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Map model output indices to BIO tags for the four condition types.
id_to_label = {
    0: 'O',
    1: 'B-TOPIC',
    2: 'I-TOPIC',
    3: 'B-STYLE',
    4: 'I-STYLE',
    5: 'B-LENGTH',
    6: 'I-LENGTH',
    7: 'B-LANGUAGE',
    8: 'I-LANGUAGE'
}

@st.cache_resource
def load_model():
    # Load the tokenizer and fine-tuned model once and cache them across reruns.
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = AutoModelForTokenClassification.from_pretrained(".")
    return tokenizer, model


tokenizer, model = load_model()
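# Note: from_pretrained(".") loads the model and tokenizer files (config.json,
# weights, tokenizer assets) stored alongside this script in the repository.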


def predict(text, model, tokenizer, id_to_label):
    # Split the input into individual characters and tag each one with the model.
    tokens = list(text)
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, max_length=128)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    # Map each subword prediction back to the original character it came from.
    word_ids = inputs.word_ids(batch_index=0)
    pred_labels = []
    tokens_out = []
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:  # special tokens ([CLS], [SEP]) have no source character
            continue
        token = tokens[word_idx]
        label = id_to_label[predictions[0][idx].item()]
        tokens_out.append(token)
        pred_labels.append(label)
    return tokens_out, pred_labels


def post_process(tokens, labels):
    # Merge "##" subword pieces back into whole words; each word keeps the label
    # of its first piece, and special tokens are skipped.
    words, word_labels = [], []
    current_word = ""
    current_label = None
    for token, label in zip(tokens, labels):
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        if token.startswith("##"):
            current_word += token[2:]
        else:
            if current_word:
                words.append(current_word)
                word_labels.append(current_label)
            current_word = token
            current_label = label
    # Flush the last accumulated word.
    if current_word:
        words.append(current_word)
        word_labels.append(current_label)
    return words, word_labels
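# Illustrative example (hand-written tokens, not actual model output):
#   post_process(["학", "##습"], ["B-TOPIC", "I-TOPIC"])
#   -> (["학습"], ["B-TOPIC"])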


def align_words_labels(words, labels):
    return list(zip(words, labels))


def extract_entities(aligned_result):
    # Walk the BIO tags and collect contiguous spans into entity records.
    entities, current_entity, current_text = [], None, ""
    for word, label in aligned_result:
        if label == "O":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = None, ""
            continue
        prefix, entity_type = label.split("-", 1)
        if prefix == "B":
            # A new entity starts; close any open one first.
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
        elif prefix == "I" and current_entity == entity_type:
            # Continuation of the current entity.
            current_text += word
        else:
            # An I-tag without a matching open entity is treated as a new entity.
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
    # Flush the final open entity, if any.
    if current_entity:
        entities.append({"entity": current_entity, "text": current_text})
    return entities
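# Illustrative example (hand-written labels, not actual model output):
#   extract_entities([("딥러닝을", "B-TOPIC"), ("실습", "B-STYLE")])
#   -> [{"entity": "TOPIC", "text": "딥러닝을"}, {"entity": "STYLE", "text": "실습"}]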


# Streamlit UI
st.title("🎯 Learning Condition Extractor")
st.write("Extracts conditions (TOPIC, STYLE, LENGTH, LANGUAGE) from the user's learning-goal sentence.")
# Default value is a Korean example: "I want to learn deep learning with a focus on hands-on practice, within 30 minutes."
user_input = st.text_input("Enter your learning goal:", value="딥러닝을 실습 위주로 30분 이내에 배우고 싶어요")
if st.button("Start inference"):
    tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
    words, word_labels = post_process(tokens, pred_labels)
    aligned = align_words_labels(words, word_labels)
    entities = extract_entities(aligned)
    # Keep the last extracted span for each condition type; missing ones stay None.
    result_dict = {'TOPIC': None, 'STYLE': None, 'LENGTH': None, 'LANGUAGE': None}
    for ent in entities:
        result_dict[ent['entity']] = ent['text']
    st.subheader("📌 Extracted conditions")
    st.json(result_dict)