import re
import gradio as gr
from gliner import GLiNER
from cerberus import Validator
from transformers import AutoTokenizer

# ----------------------------------------------------------------------------
# Load model + labels
# ----------------------------------------------------------------------------

model = GLiNER.from_pretrained("urchade/gliner_multi_pii-v1")
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

with open("labels.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f.readlines()]

MAX_TOKENS = 512  # безопасный лимит токенов на один фрагмент

# ----------------------------------------------------------------------------
# Simple Cerberus validation for incoming data
# ----------------------------------------------------------------------------

schema = {
    "text": {
        "type": "string",
        "required": True,  # without this, Cerberus accepts a payload with no "text" key
        "empty": False
    }
}

validator = Validator(schema)

def validate_input(data: dict) -> str:
    """Validate the incoming payload against the Cerberus schema and return the text."""
    if not validator.validate(data):
        raise ValueError(f"Invalid input data. Errors: {validator.errors}")
    return data["text"]

# ----------------------------------------------------------------------------
# Chunking + Anonymization logic
# ----------------------------------------------------------------------------

def split_text_into_chunks(text, max_tokens=MAX_TOKENS):
    """Greedily pack whitespace-separated words into chunks of at most
    max_tokens subword tokens. Returns (chunk_text, char_offset) pairs;
    offsets assume single-space joins, so they are approximate if the
    original text used other whitespace."""
    words = text.split()
    chunks = []
    chunk = []
    chunk_token_count = 0
    current_offset = 0

    for word in words:
        token_count = len(tokenizer.tokenize(word))
        # Only start a new chunk if the current one is non-empty; otherwise a
        # single over-long word would emit an empty chunk.
        if chunk and chunk_token_count + token_count > max_tokens:
            chunk_text = ' '.join(chunk)
            chunks.append((chunk_text, current_offset))
            current_offset += len(chunk_text) + 1
            chunk = [word]
            chunk_token_count = token_count
        else:
            chunk.append(word)
            chunk_token_count += token_count

    if chunk:
        chunk_text = ' '.join(chunk)
        chunks.append((chunk_text, current_offset))

    return chunks
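
# A minimal sketch of the chunker's behavior with a toy limit (real calls use
# MAX_TOKENS=512), assuming each word tokenizes to a single subword:
#
#   split_text_into_chunks("one two three four five", max_tokens=4)
#   # -> [("one two three four", 0), ("five", 19)]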

def anonymize_text_long(text):
    """Run GLiNER over each chunk and replace detected entities with stable
    <PII_LABEL_N> placeholders. Returns the anonymized text and a map from
    label to the list of original strings (placeholder N maps to index N-1)."""
    chunks = split_text_into_chunks(text)
    full_anonymized = ""
    global_entity_map = {}

    for chunk_text, _ in chunks:
        entities = model.predict_entities(chunk_text, labels=labels, threshold=0.2)
        entities.sort(key=lambda e: e['start'])

        anonymized_chunk = ""
        next_start = 0

        for entity in entities:
            label = entity['label'].replace(" ", "_").upper()
            original_text = entity['text']
            start_idx, end_idx = entity['start'], entity['end']

            # Skip spans that overlap an entity we have already replaced.
            if start_idx < next_start:
                continue

            # Reuse the index of a string seen earlier so repeated mentions map
            # to the same placeholder across chunks.
            if label not in global_entity_map:
                global_entity_map[label] = [original_text]
                idx = 1
            elif original_text in global_entity_map[label]:
                idx = global_entity_map[label].index(original_text) + 1
            else:
                global_entity_map[label].append(original_text)
                idx = len(global_entity_map[label])

            anonymized_chunk += chunk_text[next_start:start_idx]
            anonymized_chunk += f"<PII_{label}_{idx}>"
            next_start = end_idx

        anonymized_chunk += chunk_text[next_start:]
        full_anonymized += anonymized_chunk + " "

    return full_anonymized.strip(), global_entity_map
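
# Illustration of the expected placeholder format (hypothetical model output;
# actual labels come from labels.txt and actual spans from GLiNER):
#
#   anonymize_text_long("My name is John Smith, call me at 555-0123.")
#   # -> ("My name is <PII_FULL_NAME_1>, call me at <PII_PHONE_NUMBER_1>.",
#   #     {"FULL_NAME": ["John Smith"], "PHONE_NUMBER": ["555-0123"]})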

# ----------------------------------------------------------------------------
# De-anonymization logic
# ----------------------------------------------------------------------------

def deanonymize_text(anonymized_response, entity_map):
    """Replace <PII_LABEL_N> placeholders with the original strings recorded
    in entity_map; placeholders without a matching entry are left unchanged."""
    def replace_match(match):
        label = match.group(1)
        idx = int(match.group(2)) - 1  # placeholders are 1-based
        if label in entity_map and 0 <= idx < len(entity_map[label]):
            return entity_map[label][idx]
        return match.group(0)

    pattern = r"<PII_(\w+)_(\d+)>"
    return re.sub(pattern, replace_match, anonymized_response)
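
# Round-trip sketch using the hypothetical entity map from the example above:
#
#   deanonymize_text("Hello <PII_FULL_NAME_1>!", {"FULL_NAME": ["John Smith"]})
#   # -> "Hello John Smith!"
#
# Placeholders with no map entry (e.g. <PII_CITY_3>) are left intact.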

# ----------------------------------------------------------------------------
# Gradio Interface
# ----------------------------------------------------------------------------

def anonymize_fn(original_text):
    data = {"text": original_text}
    try:
        user_text = validate_input(data)
    except ValueError as e:
        return "", {}, f"Validation error: {str(e)}"

    anonymized, entities = anonymize_text_long(user_text)
    return anonymized, entities, "Successfully anonymized!"

def deanonymize_fn(anonymized_llm_response, entity_map):
    if not anonymized_llm_response.strip():
        return "", "Вставьте анонимизированный текст."
    if not entity_map:
        return "", "No entity map found; anonymize some text first."

    result = deanonymize_text(anonymized_llm_response, entity_map)
    return result, "Успешно деанонимизировано!"

md_text = """# Анонимизатор психотерапевтических сессий

Вставьте текст в раздел \"Исходный текст\", чтобы анонимизировать сензитивные данные.
"""

with gr.Blocks() as demo:
    gr.Markdown(md_text)

    with gr.Row():
        with gr.Column():
            original_text = gr.Textbox(
                lines=6, label="Source text (to anonymize)"
            )
            anonymized_text = gr.Textbox(
                lines=6, label="Anonymized text", interactive=False
            )
            button_anon = gr.Button("Anonymize")

            entity_map_state = gr.State()
            message_out = gr.Textbox(label="Status", interactive=False)

            button_anon.click(
                anonymize_fn,
                inputs=[original_text],
                outputs=[anonymized_text, entity_map_state, message_out]
            )

        with gr.Column():
            anonymized_llm_response = gr.Textbox(
                lines=6, label="Anonymized session (paste here)"
            )
            deanonymized_text = gr.Textbox(
                lines=6, label="De-anonymized session", interactive=False
            )
            button_deanon = gr.Button("De-anonymize")

            message_out_de = gr.Textbox(label="Status", interactive=False)

            button_deanon.click(
                deanonymize_fn,
                inputs=[anonymized_llm_response, entity_map_state],
                outputs=[deanonymized_text, message_out_de]
            )

if __name__ == "__main__":
    demo.launch()