MentalTech committed on
Commit
9c85f16
·
verified ·
1 Parent(s): ec93d27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -61
app.py CHANGED
@@ -2,21 +2,24 @@ import re
2
  import gradio as gr
3
  from gliner import GLiNER
4
  from cerberus import Validator
 
5
 
6
  # ----------------------------------------------------------------------------
7
  # Load model + labels
8
  # ----------------------------------------------------------------------------
9
 
10
  model = GLiNER.from_pretrained("urchade/gliner_multi_pii-v1")
 
11
 
12
  with open("labels.txt", "r", encoding="utf-8") as f:
13
  labels = [line.strip() for line in f.readlines()]
14
 
 
 
15
  # ----------------------------------------------------------------------------
16
  # Simple Cerberus validation for incoming data
17
  # ----------------------------------------------------------------------------
18
 
19
- # We expect a dict with at least {"text": "<some string>"}
20
  schema = {
21
  "text": {
22
  "type": "string",
@@ -26,74 +29,88 @@ schema = {
26
 
27
  validator = Validator(schema)
28
 
29
-
30
def validate_input(data: dict) -> str:
    """Run *data* through the Cerberus validator and return its 'text' value.

    Args:
        data: Payload expected to contain at least a ``"text"`` string key.

    Returns:
        The validated ``"text"`` string.

    Raises:
        ValueError: when the payload does not match the schema; the message
            embeds ``validator.errors`` for debugging.
    """
    if validator.validate(data):
        return data["text"]
    raise ValueError(f"Invalid input data. Errors: {validator.errors}")
36
 
37
  # ----------------------------------------------------------------------------
38
- # Core anonymize / de-anonymize logic (same as before)
39
  # ----------------------------------------------------------------------------
40
 
41
-
42
def anonymize_text(text):
    """Detect PII with GLiNER and mask every entity with a placeholder.

    Each detected span is replaced by ``<PII_<LABEL>_<N>>`` where N is a
    1-based index into the per-label list of distinct surface strings, so
    the same string always maps to the same placeholder.

    Returns:
        Tuple of ``(anonymized_text, entity_map)`` where ``entity_map``
        maps label -> list of original strings (used for de-anonymization).
    """
    found = model.predict_entities(text, labels=labels, threshold=0.2)
    # Walk entities left-to-right so the untouched text between them is
    # copied over in the correct order.
    found.sort(key=lambda e: e['start'])

    entity_map = {}
    pieces = []
    cursor = 0

    for ent in found:
        label = ent['label'].replace(" ", "_").upper()
        surface = ent['text']

        known = entity_map.setdefault(label, [])
        if surface in known:
            # Exact repeat of an earlier string: reuse its index.
            idx = known.index(surface) + 1
        else:
            known.append(surface)
            idx = len(known)

        pieces.append(text[cursor:ent['start']])
        pieces.append(f"<PII_{label}_{idx}>")
        cursor = ent['end']

    # Remainder of the text after the last entity.
    pieces.append(text[cursor:])
    return "".join(pieces), entity_map
81
 
 
82
 
83
def deanonymize_text(anonymized_response, entity_map):
    """Restore the original string for every ``<PII_LABEL_IDX>`` placeholder.

    Placeholders whose label or (1-based) index cannot be resolved against
    *entity_map* are left untouched.
    """
    placeholder_re = r"<PII_(\w+)_(\d+)>"

    def restore(match):
        label, idx_text = match.group(1), match.group(2)
        position = int(idx_text) - 1  # placeholder indices are 1-based
        originals = entity_map.get(label, [])
        if 0 <= position < len(originals):
            return originals[position]
        # Unknown placeholder: keep it verbatim rather than corrupt the text.
        return match.group(0)

    return re.sub(placeholder_re, restore, anonymized_response)
@@ -103,18 +120,15 @@ def deanonymize_text(anonymized_response, entity_map):
103
  # ----------------------------------------------------------------------------
104
 
105
def anonymize_fn(original_text):
    """Gradio callback: validate the raw input, then anonymize it.

    Returns:
        A 3-tuple feeding the UI outputs:
        (anonymized text, entity map, status message).
    """
    try:
        # Wrap in a dict so the Cerberus schema can be applied.
        user_text = validate_input({"text": original_text})
    except ValueError as e:
        # Surface validation problems in the status box instead of crashing.
        return "", {}, f"Validation error: {str(e)}"

    anonymized, entities = anonymize_text(user_text)
    return anonymized, entities, "Успешно анонимизировано!"
116
 
117
-
118
  def deanonymize_fn(anonymized_llm_response, entity_map):
119
  if not anonymized_llm_response.strip():
120
  return "", "Вставьте анонимизированный текст."
@@ -124,11 +138,9 @@ def deanonymize_fn(anonymized_llm_response, entity_map):
124
  result = deanonymize_text(anonymized_llm_response, entity_map)
125
  return result, "Успешно деанонимизировано!"
126
 
127
-
128
  md_text = """# Анонимизатор психотерапевтических сессий
129
 
130
- Вставьте текст в раздел "Исходный текст", чтобы анонимизировать сензитивные данные.
131
-
132
  """
133
 
134
  with gr.Blocks() as demo:
@@ -144,9 +156,7 @@ with gr.Blocks() as demo:
144
  )
145
  button_anon = gr.Button("Анонимизировать")
146
 
147
- # Hidden state to store the entity map
148
  entity_map_state = gr.State()
149
-
150
  message_out = gr.Textbox(label="Status", interactive=False)
151
 
152
  button_anon.click(
@@ -173,4 +183,4 @@ with gr.Blocks() as demo:
173
  )
174
 
175
  if __name__ == "__main__":
176
- demo.launch()
 
2
  import gradio as gr
3
  from gliner import GLiNER
4
  from cerberus import Validator
5
+ from transformers import AutoTokenizer
6
 
7
  # ----------------------------------------------------------------------------
8
  # Load model + labels
9
  # ----------------------------------------------------------------------------
10
 
11
  model = GLiNER.from_pretrained("urchade/gliner_multi_pii-v1")
12
+ tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
13
 
14
  with open("labels.txt", "r", encoding="utf-8") as f:
15
  labels = [line.strip() for line in f.readlines()]
16
 
17
+ MAX_TOKENS = 512 # безопасный лимит токенов на один фрагмент
18
+
19
  # ----------------------------------------------------------------------------
20
  # Simple Cerberus validation for incoming data
21
  # ----------------------------------------------------------------------------
22
 
 
23
  schema = {
24
  "text": {
25
  "type": "string",
 
29
 
30
  validator = Validator(schema)
31
 
 
32
def validate_input(data: dict) -> str:
    """Validate *data* against the Cerberus schema and return the 'text' field.

    Raises:
        ValueError: if validation fails; the message carries
            ``validator.errors`` so the caller can display them.
    """
    is_valid = validator.validate(data)
    if is_valid:
        return data["text"]
    raise ValueError(f"Invalid input data. Errors: {validator.errors}")
36
 
37
  # ----------------------------------------------------------------------------
38
+ # Chunking + Anonymization logic
39
  # ----------------------------------------------------------------------------
40
 
41
def split_text_into_chunks(text, max_tokens=MAX_TOKENS):
    """Split *text* into whitespace-delimited chunks of at most *max_tokens* tokens.

    Token counts come from the module-level tokenizer so chunks stay within
    the encoder's input limit before being fed to GLiNER.

    Args:
        text: Input string; it is split on whitespace, so runs of whitespace
            are collapsed to single spaces inside each chunk.
        max_tokens: Soft upper bound on tokens per chunk. A single word
            longer than the limit still becomes its own (oversized) chunk.

    Returns:
        List of ``(chunk_text, offset)`` tuples. ``offset`` is the chunk's
        character position within the space-rejoined text — NOTE(review):
        not the original string, since ``str.split`` collapses whitespace;
        current callers ignore the offset.
    """
    words = text.split()
    chunks = []
    chunk = []
    chunk_token_count = 0
    current_offset = 0

    for word in words:
        token_count = len(tokenizer.tokenize(word))
        # BUG FIX: only flush when the current chunk is non-empty. The old
        # code emitted a spurious empty chunk whenever the very first word
        # of a chunk already exceeded the token limit.
        if chunk and chunk_token_count + token_count > max_tokens:
            chunk_text = ' '.join(chunk)
            chunks.append((chunk_text, current_offset))
            current_offset += len(chunk_text) + 1  # +1 for the joining space
            chunk = [word]
            chunk_token_count = token_count
        else:
            chunk.append(word)
            chunk_token_count += token_count

    # Flush the trailing partial chunk, if any.
    if chunk:
        chunk_text = ' '.join(chunk)
        chunks.append((chunk_text, current_offset))

    return chunks
65
+
66
def anonymize_text_long(text):
    """Anonymize arbitrarily long text by chunking it before GLiNER inference.

    Each chunk is scanned independently, but the entity map is shared across
    chunks, so a string that recurs anywhere in the text always receives the
    same ``<PII_LABEL_IDX>`` placeholder.

    Returns:
        Tuple of ``(anonymized_text, global_entity_map)`` where the map is
        label -> list of original strings (used for de-anonymization).
    """
    global_entity_map = {}
    anonymized_chunks = []

    for chunk_text, _offset in split_text_into_chunks(text):
        detections = model.predict_entities(chunk_text, labels=labels, threshold=0.2)
        # Left-to-right order so inter-entity text is copied correctly.
        detections.sort(key=lambda e: e['start'])

        parts = []
        cursor = 0
        for det in detections:
            label = det['label'].replace(" ", "_").upper()
            surface = det['text']

            seen = global_entity_map.setdefault(label, [])
            if surface in seen:
                # Repeated string -> same placeholder index.
                idx = seen.index(surface) + 1
            else:
                seen.append(surface)
                idx = len(seen)

            parts.append(chunk_text[cursor:det['start']])
            parts.append(f"<PII_{label}_{idx}>")
            cursor = det['end']

        parts.append(chunk_text[cursor:])
        anonymized_chunks.append("".join(parts))

    # Chunks were split on whitespace, so rejoin with single spaces.
    joined = " ".join(anonymized_chunks)
    return joined.strip(), global_entity_map
101
 
102
+ # ----------------------------------------------------------------------------
103
+ # De-anonymization logic
104
+ # ----------------------------------------------------------------------------
 
 
105
 
106
def deanonymize_text(anonymized_response, entity_map):
    """Substitute each ``<PII_LABEL_IDX>`` placeholder with its original string.

    Args:
        anonymized_response: Text containing zero or more placeholders.
        entity_map: Mapping of label -> list of original strings, as produced
            by the anonymization step.

    Returns:
        The text with every resolvable placeholder replaced; unresolvable
        placeholders (unknown label or out-of-range index) are kept as-is.
    """
    def replace_match(match):
        label = match.group(1)
        idx = int(match.group(2)) - 1  # 1-based placeholder -> 0-based list
        values = entity_map.get(label)
        if values is not None and 0 <= idx < len(values):
            return values[idx]
        return match.group(0)  # leave unknown placeholders untouched

    return re.sub(r"<PII_(\w+)_(\d+)>", replace_match, anonymized_response)
 
120
  # ----------------------------------------------------------------------------
121
 
122
def anonymize_fn(original_text):
    """Gradio callback: validate the input, then run chunked anonymization.

    Returns:
        A 3-tuple for the UI outputs:
        (anonymized text, entity map, status message).
    """
    payload = {"text": original_text}
    try:
        user_text = validate_input(payload)
    except ValueError as e:
        # Report schema failures through the status output, not a traceback.
        return "", {}, f"Validation error: {str(e)}"

    anonymized, entities = anonymize_text_long(user_text)
    return anonymized, entities, "Успешно анонимизировано!"
131
 
 
132
  def deanonymize_fn(anonymized_llm_response, entity_map):
133
  if not anonymized_llm_response.strip():
134
  return "", "Вставьте анонимизированный текст."
 
138
  result = deanonymize_text(anonymized_llm_response, entity_map)
139
  return result, "Успешно деанонимизировано!"
140
 
 
141
  md_text = """# Анонимизатор психотерапевтических сессий
142
 
143
+ Вставьте текст в раздел \"Исходный текст\", чтобы анонимизировать сензитивные данные.
 
144
  """
145
 
146
  with gr.Blocks() as demo:
 
156
  )
157
  button_anon = gr.Button("Анонимизировать")
158
 
 
159
  entity_map_state = gr.State()
 
160
  message_out = gr.Textbox(label="Status", interactive=False)
161
 
162
  button_anon.click(
 
183
  )
184
 
185
  if __name__ == "__main__":
186
+ demo.launch()