import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Maps model output ids to BIO labels for the four condition slots.
id_to_label = {
    0: 'O',
    1: 'B-TOPIC',
    2: 'I-TOPIC',
    3: 'B-STYLE',
    4: 'I-STYLE',
    5: 'B-LENGTH',
    6: 'I-LENGTH',
    7: 'B-LANGUAGE',
    8: 'I-LANGUAGE'
}
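# In this BIO scheme, "B-" marks the first character of a span, "I-" its
# continuation, and "O" everything outside any span; a three-character
# TOPIC is therefore tagged B-TOPIC, I-TOPIC, I-TOPIC.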

@st.cache_resource
def load_model():
    # "." assumes the fine-tuned model and tokenizer files are saved
    # alongside this script (the usual layout for a Hugging Face Space).
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = AutoModelForTokenClassification.from_pretrained(".")
    return tokenizer, model

tokenizer, model = load_model()

def predict(text, model, tokenizer, id_to_label):
    # The model works on character-level tokens, so split the input into
    # individual characters and let the tokenizer align subwords to them.
    tokens = list(text)
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, max_length=128)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)

    # Map subword predictions back to the input characters. Special tokens
    # ([CLS]/[SEP]) have a word id of None; when one character is split into
    # several subwords, keep only the first subword's label so the character
    # is not emitted more than once.
    word_ids = inputs.word_ids(batch_index=0)
    pred_labels, tokens_out = [], []
    prev_word_idx = None

    for idx, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx == prev_word_idx:
            continue
        tokens_out.append(tokens[word_idx])
        pred_labels.append(id_to_label[predictions[0][idx].item()])
        prev_word_idx = word_idx

    return tokens_out, pred_labels
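
# Rough sketch of the returned shape (the labels here are hypothetical,
# not real model output):
#   tokens_out  = ['딥', '러', '닝', '을', ...]
#   pred_labels = ['B-TOPIC', 'I-TOPIC', 'I-TOPIC', 'O', ...]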

def post_process(tokens, labels):
    # Merge wordpiece fragments ("##...") back into whole words, dropping
    # special tokens and keeping the label of each word's first piece. With
    # the character-level output of predict() the "##" branch never fires,
    # but it keeps the function safe for raw wordpiece input as well.
    words, word_labels = [], []
    current_word = ""
    current_label = None
    for token, label in zip(tokens, labels):
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        if token.startswith("##"):
            current_word += token[2:]
        else:
            if current_word:
                words.append(current_word)
                word_labels.append(current_label)
            current_word = token
            current_label = label
    if current_word:
        words.append(current_word)
        word_labels.append(current_label)
    return words, word_labels

def align_words_labels(words, labels):
    # Pair each word with its predicted label for entity extraction.
    return list(zip(words, labels))

def extract_entities(aligned_result):
    # Walk the BIO sequence and collect contiguous spans: "B-" opens a new
    # entity, "I-" of the matching type extends it, and anything else
    # (including an "I-" whose type does not match) closes the current span.
    entities, current_entity, current_text = [], None, ""
    for word, label in aligned_result:
        if label == "O":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
                current_entity, current_text = None, ""
            continue
        prefix, entity_type = label.split("-", 1)
        if prefix == "B":
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
        elif prefix == "I" and current_entity == entity_type:
            current_text += word
        else:
            if current_entity:
                entities.append({"entity": current_entity, "text": current_text})
            current_entity, current_text = entity_type, word
    if current_entity:
        entities.append({"entity": current_entity, "text": current_text})
    return entities
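
# For example (traceable by hand through the loop above):
#   extract_entities([("딥", "B-TOPIC"), ("러", "I-TOPIC"), ("닝", "I-TOPIC"),
#                     ("을", "O"), ("영", "B-LANGUAGE"), ("어", "I-LANGUAGE")])
#   -> [{"entity": "TOPIC", "text": "딥러닝"},
#       {"entity": "LANGUAGE", "text": "영어"}]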

# Streamlit UI
st.title("🎯 Learning Condition Extractor")
st.write("Extracts the conditions (TOPIC, STYLE, LENGTH, LANGUAGE) from a user's learning-goal sentence.")

# Default input (Korean): "I want to learn deep learning hands-on within 30 minutes."
user_input = st.text_input("Enter a learning goal:", value="딥러닝을 실습 위주로 30분 이내에 배우고 싶어요")

if st.button("Run inference"):
    tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
    words, word_labels = post_process(tokens, pred_labels)
    aligned = align_words_labels(words, word_labels)
    entities = extract_entities(aligned)

    # One value per slot; if a slot is predicted more than once, the last
    # occurrence wins.
    result_dict = {'TOPIC': None, 'STYLE': None, 'LENGTH': None, 'LANGUAGE': None}
    for ent in entities:
        result_dict[ent['entity']] = ent['text']

    st.subheader("📌 Extracted Conditions")
    st.json(result_dict)
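    # Illustrative output only (actual predictions depend on the trained
    # model); for the default sentence it might look like:
    #   {"TOPIC": "딥러닝", "STYLE": "실습", "LENGTH": "30분", "LANGUAGE": null}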