onath committed on
Commit
cf07861
·
verified ·
1 Parent(s): 9cb62f7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +302 -40
src/streamlit_app.py CHANGED
@@ -1,40 +1,302 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # base = '/Users/oikantik/expts_check_samples_ocr_quality'
2
+ import streamlit as st, os, json, glob, pandas as pd
3
+ from PIL import Image
4
+
5
+ # ───────── CONFIG ────────────────────────────────────────────────────────────
6
# Language code -> display name; insertion order is the order of the sidebar
# language selector.
langs_dict = {
    'hi': 'Hindi',
    'bn': 'Bengali',
    'pa': 'Punjabi',
    'or': 'Odia',
    'ta': 'Tamil',
    'te': 'Telugu',
    'kn': 'Kannada',
    'ml': 'Malayalam',
    'mr': 'Marathi',
    'gu': 'Gujarati',
}

# Two-letter document-domain prefix -> tab title; insertion order is tab order.
doc_categories = {
    'mg': 'magazines',
    'tb': 'textbooks',
    'nv': 'novels',
    'np': 'newspapers',
    'rp': 'research-papers',
    'br': 'brochures',
    'nt': 'notices',
    'sy': 'syllabi',
    'qp': 'question-papers',
    'mn': 'manuals',
}

# Root folder holding the snippet images and both sets of OCR outputs.
base = '/files/expts_check_samples_ocr_quality'
img_dir = f'{base}/ocr_snippets_testing'
gcp_dir = f'{base}/gcp_ocr_snippets'
gem_dir = f'{base}/gemini_ocr_snippets'

# Persistence files (relative to the working directory) and the ratings schema.
RATING_FILE = 'ratings.csv'
UI_STATE_FILE = 'ui_state.json'
COLS = ['image_name', 'lang', 'domain', 'image_rating', 'ocr_pred_rating']

# Sentinel rating values.
DEFAULT = -1  # not rated
SKIP = -2     # skipped
23
+
24
+ # ───────── HELPERS ───────────────────────────────────────────────────────────
25
def read_json(path, default):
    """Load a JSON document from *path*, falling back to *default*.

    Args:
        path: Filesystem path of the JSON file.
        default: Value returned when the file does not exist or does not
            contain valid JSON.

    Returns:
        The decoded JSON value, or *default* itself (not a copy).
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(path, encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # A missing or half-written file must not take the whole app down;
        # start again from the caller's default instead.
        return default
31
+
32
def write_json(path, obj):
    """Overwrite *path* with *obj* serialised as 2-space-indented JSON."""
    with open(path, 'w') as handle:
        handle.write(json.dumps(obj, indent=2))
35
+
36
def load_ratings():
    """Return the ratings table stored in ``RATING_FILE`` as a DataFrame.

    On first use the file does not exist yet; it is then created with just
    the ``COLS`` header row before being read back.
    """
    if not os.path.exists(RATING_FILE):
        # Seed an empty, header-only CSV so the read below always succeeds.
        pd.DataFrame(columns=COLS).to_csv(RATING_FILE, index=False)
    return pd.read_csv(RATING_FILE)
41
+
42
def safe_json(path):
    """Load the JSON file at *path*, or return ``None`` if it is unusable.

    Args:
        path: Filesystem path of an OCR result file.

    Returns:
        The decoded JSON value, or ``None`` when the file is missing or does
        not contain valid JSON.
    """
    try:
        # The OCR outputs hold non-ASCII script text; read them as UTF-8
        # explicitly rather than with the platform's default encoding.
        with open(path, encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # Callers treat None as "no OCR output available" and show a placeholder.
        return None
48
+
49
def gcp_text(path):
    """Return the GCP OCR text stored at *path* as one space-joined string.

    The text is read from ``js['ocr_output']['blocks'][i]['block_text']``.

    Args:
        path: Path of the GCP OCR JSON file for one snippet.

    Returns:
        The block texts joined by single spaces (an empty string when there
        are no blocks), or the placeholder ``'—'`` when the file is missing,
        empty, or not a JSON object.
    """
    js = safe_json(path)
    if not js or not isinstance(js, dict):
        return '—'
    # `or` fallbacks also cover keys that are present but null.
    blocks = (js.get('ocr_output') or {}).get('blocks') or []
    # Tolerate blocks that lack 'block_text' instead of raising KeyError.
    return ' '.join(
        b.get('block_text') or '' for b in blocks if isinstance(b, dict)
    )
56
+
57
def gem_text(path):
    """Return the Gemini OCR text stored at *path* as one space-joined string.

    The text is read from ``js['candidates'][0]['content']['parts'][i]['text']``.

    Args:
        path: Path of the Gemini response JSON file for one snippet.

    Returns:
        The part texts of the first candidate joined by single spaces, or the
        placeholder ``'—'`` when the file is missing or yields no parts.
    """
    js = safe_json(path)
    if js and isinstance(js, dict):
        # `or [{}]` also covers an empty or null candidates list, which would
        # otherwise raise IndexError/TypeError on the [0] lookup.
        first = (js.get('candidates') or [{}])[0] or {}
        parts = (first.get('content') or {}).get('parts') or []
        if parts:
            return ' '.join(
                p.get('text', '') for p in parts if isinstance(p, dict)
            )
    return '—'
70
+
71
def md15(label, txt):
    """Render *txt* under a bold *label* in a 15px block via ``st.markdown``.

    Args:
        label: Heading shown in bold above the text.
        txt: Body text (OCR output) shown below the heading.
    """
    import html  # local import: only this helper needs it

    # The block is rendered with unsafe_allow_html=True, so escape the values
    # first: OCR text containing '<', '>' or '&' would otherwise be parsed as
    # markup and could break the page.
    safe_label = html.escape(str(label))
    safe_txt = html.escape(str(txt))
    st.markdown(
        f'<div style="font-size:15px;"><b>{safe_label}</b><br>{safe_txt}</div>',
        unsafe_allow_html=True,
    )
76
+
77
+ # ───────── STATE INIT ────────────────────────────────────────────────────────
78
# Ratings recorded so far; update_csv() rebinds this module-level frame.
ratings_df = load_ratings()
ui_state = read_json(
    UI_STATE_FILE,
    {"last_lang": None, "show_completed": False, "view_completed": False},
)
# Snapshot of the loaded state so the file is rewritten only when a sidebar
# widget actually changed something.
saved_ui_state = dict(ui_state)

# ───────── SIDEBAR ───────────────────────────────────────────────────────────

# language selector (pre-selects the language remembered in ui_state)
lang_names = list(langs_dict.values())
default_lang = ui_state.get("last_lang")
default_lang_idx = lang_names.index(default_lang) if default_lang in lang_names else 0
lang_name = st.sidebar.selectbox('Language', lang_names, index=default_lang_idx)
ui_state["last_lang"] = lang_name  # remember selection
lang_code = next(k for k, v in langs_dict.items() if v == lang_name)

# Rows of the selected language, and the subset whose image AND OCR ratings
# are both set. This is the same "completed" rule the completed-snippets view
# uses; skipped snippets carry SKIP in both columns and so count as done.
lang_rows = ratings_df[ratings_df.lang == lang_code]
rated_rows = lang_rows[
    (lang_rows.image_rating != DEFAULT) & (lang_rows.ocr_pred_rating != DEFAULT)
]

# overall progress
total_lang = len(glob.glob(os.path.join(img_dir, lang_code, '*')))
done_lang = rated_rows.image_name.nunique()
st.sidebar.markdown(f'**Progress:** {done_lang}/{total_lang}')

# per-domain progress
with st.sidebar.expander('Per-domain progress'):
    for dk, dn in doc_categories.items():
        total = len(glob.glob(os.path.join(img_dir, lang_code, f'{dk}_{lang_code}_*')))
        done = rated_rows[rated_rows.domain == dk].image_name.nunique()
        st.write(f'{dn}: {done}/{total}')

# completed-table toggle
show_tbl = st.sidebar.checkbox(
    'Show completed table',
    value=ui_state.get("show_completed", False),  # safe default
)
ui_state["show_completed"] = show_tbl

if show_tbl:
    st.sidebar.dataframe(lang_rows[COLS], use_container_width=True)

# visual review toggle
view_comp = st.sidebar.checkbox(
    'View completed visually',
    value=ui_state.get("view_completed", False),  # safe default
)
ui_state["view_completed"] = view_comp

# persist sidebar choices, but skip the disk write on reruns that changed nothing
if ui_state != saved_ui_state:
    write_json(UI_STATE_FILE, ui_state)
135
+
136
+
137
+ # ───────── CSV UPDATE --------------------------------------------------------
138
def update_csv(name, img=None, ocr=None, skip=False):
    """Record ratings for snippet *name* and save the table to ``RATING_FILE``.

    Args:
        name: Image file name identifying the snippet's row.
        img: Image-quality rating to store, or ``None`` to leave it untouched.
        ocr: OCR-comparison rating to store, or ``None`` to leave it untouched.
        skip: When true, both ratings are set to ``SKIP`` regardless of
            *img* and *ocr*.

    Rebinds the module-level ``ratings_df``. A new row takes its language
    from the module-level ``lang_code`` and its domain from the first two
    characters of *name*; ratings not supplied start at ``DEFAULT``.
    """
    global ratings_df
    if skip:
        img, ocr = SKIP, SKIP

    row_exists = ratings_df.image_name == name
    if not row_exists.any():
        # First rating for this snippet: append a fresh row.
        new_row = {
            'image_name': name,
            'lang': lang_code,
            'domain': name[:2],
            'image_rating': DEFAULT if img is None else img,
            'ocr_pred_rating': DEFAULT if ocr is None else ocr,
        }
        ratings_df = pd.concat(
            [ratings_df, pd.DataFrame([new_row])], ignore_index=True
        )
    else:
        # Row already present: overwrite only the ratings that were supplied.
        for column, value in (('image_rating', img), ('ocr_pred_rating', ocr)):
            if value is not None:
                ratings_df.loc[row_exists, column] = value

    ratings_df.to_csv(RATING_FILE, index=False)
167
+
168
+ # ───────── MAIN – PENDING SNIPPETS ───────────────────────────────────────────
169
# One tab per document domain, each listing the snippets still to be rated.
tabs = st.tabs(list(doc_categories.values()))

for (dk, dn), tab in zip(doc_categories.items(), tabs):
    with tab:
        all_imgs = sorted(
            glob.glob(os.path.join(img_dir, lang_code, f'{dk}_{lang_code}_*'))
        )
        domain_rows = ratings_df[
            (ratings_df.lang == lang_code) & (ratings_df.domain == dk)
        ]
        # A snippet is finished only once BOTH ratings are recorded (a skip
        # sets both). Treating any existing row as finished would hide the
        # snippet after the first click, so the second rating could never be
        # given. A set keeps the membership test below O(1).
        done_imgs = set(
            domain_rows[
                (domain_rows.image_rating != DEFAULT)
                & (domain_rows.ocr_pred_rating != DEFAULT)
            ].image_name
        )
        pending = [p for p in all_imgs if os.path.basename(p) not in done_imgs]

        if not pending:
            st.success('All snippets done for this domain!')
            continue

        for file in pending:
            # File names look like <domain>_<lang>_<uid...>_<region>.<ext>.
            name = os.path.basename(file)
            stem = os.path.splitext(name)[0]
            region = name.split('_')[-1].split('.')[0]
            uid = '_'.join(name.split('_')[2:-1])

            with st.container():
                c1, c2 = st.columns([1, 2], gap='large')

                # image + rating buttons
                with c1:
                    st.image(Image.open(file))
                    st.markdown(
                        f'**File:** {name}<br>**UID:** {uid}<br>**Region:** {region}',
                        unsafe_allow_html=True,
                    )
                    b1, b2, b3, b4 = st.columns(4)
                    if b1.button('👎', key=f'{stem}_img0'):
                        update_csv(name, img=0)
                    if b2.button('😐', key=f'{stem}_img1'):
                        update_csv(name, img=1)
                    if b3.button('👍', key=f'{stem}_img2'):
                        update_csv(name, img=2)
                    if b4.button('⏭️', key=f'{stem}_skip'):
                        update_csv(name, skip=True)

                # ocr texts + comparison buttons
                with c2:
                    md15(
                        'GCP OCR',
                        gcp_text(os.path.join(gcp_dir, lang_code, f'{stem}.json')),
                    )
                    st.markdown('<hr>', unsafe_allow_html=True)
                    md15(
                        'Gemini OCR',
                        gem_text(os.path.join(gem_dir, lang_code, f'{stem}.json')),
                    )
                    st.markdown('<hr>', unsafe_allow_html=True)
                    t1, t2, t3 = st.columns(3)
                    if t1.button('👍 GCP', key=f'{stem}_ocr0'):
                        update_csv(name, ocr=0)
                    if t2.button('😐 Equal', key=f'{stem}_ocr1'):
                        update_csv(name, ocr=1)
                    if t3.button('👍 Gemini', key=f'{stem}_ocr2'):
                        update_csv(name, ocr=2)

                st.markdown('---')
237
+
238
+ # ───────── VISUALISE COMPLETED ───────────────────────────────────────────────
239
# Read-only review of snippets whose image and OCR ratings are both recorded.
if ui_state["view_completed"]:
    st.header('✅ Completed snippets')
    comp_tabs = st.tabs(list(doc_categories.values()))

    # Rating value -> badge text; built once instead of once per row.
    img_badges = {0: '👎', 1: '😐', 2: '👍', SKIP: '⏭️'}
    ocr_badges = {
        0: 'GCP better',
        1: 'Equal',
        2: 'Gemini better',
        SKIP: 'Skipped',
    }

    for (dk, dn), ctab in zip(doc_categories.items(), comp_tabs):
        with ctab:
            done_rows = ratings_df[
                (ratings_df.lang == lang_code)
                & (ratings_df.domain == dk)
                & (ratings_df.image_rating != DEFAULT)
                & (ratings_df.ocr_pred_rating != DEFAULT)
            ]

            if done_rows.empty:
                st.info('Nothing completed here yet.')
                continue

            for _, row in done_rows.iterrows():
                # File names look like <domain>_<lang>_<uid...>_<region>.<ext>.
                file = os.path.join(img_dir, lang_code, row.image_name)
                stem = os.path.splitext(row.image_name)[0]
                region = row.image_name.split('_')[-1].split('.')[0]
                uid = '_'.join(row.image_name.split('_')[2:-1])

                with st.container():
                    c1, c2 = st.columns([1, 2], gap='large')

                    # image + static badge
                    with c1:
                        # A rated image may since have been removed from disk;
                        # warn instead of aborting the whole page.
                        if os.path.exists(file):
                            st.image(Image.open(file))
                        else:
                            st.warning(f'Image file not found: {row.image_name}')
                        st.markdown(
                            f'**File:** {row.image_name}<br>'
                            f'**UID:** {uid}<br>'
                            f'**Region:** {region}',
                            unsafe_allow_html=True,
                        )
                        # .get() keeps an unexpected value in the CSV (e.g. a
                        # blank cell read back as NaN) from raising KeyError.
                        img_badge = img_badges.get(row.image_rating, '?')
                        st.markdown(f'Image rating: **{img_badge}**')

                    # OCR texts + static badge
                    with c2:
                        md15(
                            'GCP OCR',
                            gcp_text(
                                os.path.join(gcp_dir, lang_code, f'{stem}.json')
                            ),
                        )
                        st.markdown('<hr>', unsafe_allow_html=True)
                        md15(
                            'Gemini OCR',
                            gem_text(
                                os.path.join(gem_dir, lang_code, f'{stem}.json')
                            ),
                        )
                        ocr_badge = ocr_badges.get(row.ocr_pred_rating, 'Unknown')
                        st.success(f'Chosen: {ocr_badge}')

                    st.markdown('---')