Tomatillo commited on
Commit
60f5dd2
·
verified ·
1 Parent(s): dc853ae

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +122 -67
src/streamlit_app.py CHANGED
@@ -1,8 +1,13 @@
 
 
1
  import streamlit as st
2
  import io
3
  import csv
4
- from datetime import datetime
5
  from segments import SegmentsClient
 
 
 
6
  from get_labels_from_samples import (
7
  get_samples as get_samples_objects,
8
  export_frames_and_annotations,
@@ -44,6 +49,95 @@ def parse_classes(input_str: str) -> list:
44
  return sorted(set(classes))
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def generate_csv(metrics: list, dataset_identifier: str) -> str:
48
  """
49
  Generate CSV content from list of per-sample metrics.
@@ -102,6 +196,9 @@ if api_key and dataset_identifier:
102
  if is_multisensor:
103
  sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)
104
 
 
 
 
105
  if run_button:
106
  st.session_state.csv_content = None
107
  st.session_state.error = None
@@ -122,75 +219,33 @@ if run_button:
122
  st.info("Checking dataset type...")
123
  try:
124
  target_classes = parse_classes(classes_input)
125
- client = init_client(api_key)
126
  metrics = []
127
  # Update loader after dataset type check
128
  if status_ctx is not None:
129
  status_ctx.update(label="Dataset type checked. Processing samples...", state="running")
130
- for sample in samples_objects:
131
- try:
132
- label = client.get_label(sample.uuid)
133
- labelset = getattr(label, 'labelset', '') or ''
134
- labeled_by = getattr(label, 'created_by', '') or ''
135
- reviewed_by = getattr(label, 'reviewed_by', '') or ''
136
- if is_multisensor and sensor_select and sensor_select != 'All sensors':
137
- frames_list = export_sensor_frames_and_annotations(label, sensor_select)
138
- sensor_val = sensor_select
139
- num_frames = len(frames_list)
140
- total_annotations = sum(len(f['annotations']) for f in frames_list)
141
- matching_annotations = sum(
142
- 1
143
- for f in frames_list
144
- for ann in f['annotations']
145
- if getattr(ann, 'category_id', None) in target_classes
146
- )
147
- elif is_multisensor and (not sensor_select or sensor_select == 'All sensors'):
148
- all_sensor_frames = export_all_sensor_frames_and_annotations(label)
149
- for sensor_name, frames_list in all_sensor_frames.items():
150
- num_frames = len(frames_list)
151
- total_annotations = sum(len(f['annotations']) for f in frames_list)
152
- matching_annotations = sum(
153
- 1
154
- for f in frames_list
155
- for ann in f['annotations']
156
- if getattr(ann, 'category_id', None) in target_classes
157
- )
158
- metrics.append({
159
- 'name': getattr(sample, 'name', sample.uuid),
160
- 'uuid': sample.uuid,
161
- 'labelset': labelset,
162
- 'sensor': sensor_name,
163
- 'num_frames': num_frames,
164
- 'total_annotations': total_annotations,
165
- 'matching_annotations': matching_annotations,
166
- 'labeled_by': labeled_by,
167
- 'reviewed_by': reviewed_by
168
- })
169
- continue
170
- else:
171
- frames_list = export_frames_and_annotations(label)
172
- sensor_val = ''
173
- num_frames = len(frames_list)
174
- total_annotations = sum(len(f['annotations']) for f in frames_list)
175
- matching_annotations = sum(
176
- 1
177
- for f in frames_list
178
- for ann in f['annotations']
179
- if getattr(ann, 'category_id', None) in target_classes
180
- )
181
- metrics.append({
182
- 'name': getattr(sample, 'name', sample.uuid),
183
- 'uuid': sample.uuid,
184
- 'labelset': labelset,
185
- 'sensor': sensor_val if is_multisensor else '',
186
- 'num_frames': num_frames,
187
- 'total_annotations': total_annotations,
188
- 'matching_annotations': matching_annotations,
189
- 'labeled_by': labeled_by,
190
- 'reviewed_by': reviewed_by
191
- })
192
- except Exception as e:
193
- continue
194
  if not metrics:
195
  st.session_state.error = "No metrics could be generated for the dataset."
196
  else:
@@ -213,4 +268,4 @@ if st.session_state.csv_content:
213
  data=st.session_state.csv_content,
214
  file_name=filename,
215
  mime="text/csv"
216
- )
 
1
+ #!/usr/bin/env python3
2
+
3
  import streamlit as st
4
  import io
5
  import csv
6
+ import concurrent.futures
7
  from segments import SegmentsClient
8
+ from datetime import datetime
9
+ import sys
10
+ import os
11
  from get_labels_from_samples import (
12
  get_samples as get_samples_objects,
13
  export_frames_and_annotations,
 
49
  return sorted(set(classes))
50
 
51
 
52
+ def _count_from_frames(frames, target_set):
53
+ """Helper to count frames, total annotations, and matching annotations directly."""
54
+ if not frames:
55
+ return 0, 0, 0
56
+ num_frames = len(frames)
57
+ total_annotations = 0
58
+ matching_annotations = 0
59
+ for f in frames:
60
+ anns = getattr(f, 'annotations', [])
61
+ total_annotations += len(anns)
62
+ if target_set:
63
+ for ann in anns:
64
+ if getattr(ann, 'category_id', None) in target_set:
65
+ matching_annotations += 1
66
+ return num_frames, total_annotations, matching_annotations
67
+
68
+
69
+ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sensor_select):
70
+ """
71
+ Fetch label for a single sample and compute metrics.
72
+ Returns a list of metric dicts (one per sensor if 'All sensors', otherwise one).
73
+ """
74
+ try:
75
+ client = init_client(api_key)
76
+ label = client.get_label(sample.uuid)
77
+ labelset = getattr(label, 'labelset', '') or ''
78
+ labeled_by = getattr(label, 'created_by', '') or ''
79
+ reviewed_by = getattr(label, 'reviewed_by', '') or ''
80
+
81
+ metrics_rows = []
82
+
83
+ if is_multisensor:
84
+ sensors = getattr(getattr(label, 'attributes', None), 'sensors', None) or []
85
+ if sensor_select and sensor_select != 'All sensors':
86
+ # single sensor
87
+ for sensor in sensors:
88
+ if getattr(sensor, 'name', None) == sensor_select:
89
+ frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
90
+ num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
91
+ metrics_rows.append({
92
+ 'name': getattr(sample, 'name', sample.uuid),
93
+ 'uuid': sample.uuid,
94
+ 'labelset': labelset,
95
+ 'sensor': sensor_select,
96
+ 'num_frames': num_frames,
97
+ 'total_annotations': total_annotations,
98
+ 'matching_annotations': matching_annotations,
99
+ 'labeled_by': labeled_by,
100
+ 'reviewed_by': reviewed_by
101
+ })
102
+ break
103
+ else:
104
+ # all sensors
105
+ for sensor in sensors:
106
+ sensor_name = getattr(sensor, 'name', 'Unknown')
107
+ frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
108
+ num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
109
+ metrics_rows.append({
110
+ 'name': getattr(sample, 'name', sample.uuid),
111
+ 'uuid': sample.uuid,
112
+ 'labelset': labelset,
113
+ 'sensor': sensor_name,
114
+ 'num_frames': num_frames,
115
+ 'total_annotations': total_annotations,
116
+ 'matching_annotations': matching_annotations,
117
+ 'labeled_by': labeled_by,
118
+ 'reviewed_by': reviewed_by
119
+ })
120
+ else:
121
+ # single-sensor dataset
122
+ frames = getattr(getattr(label, 'attributes', None), 'frames', [])
123
+ num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
124
+ metrics_rows.append({
125
+ 'name': getattr(sample, 'name', sample.uuid),
126
+ 'uuid': sample.uuid,
127
+ 'labelset': labelset,
128
+ 'sensor': '',
129
+ 'num_frames': num_frames,
130
+ 'total_annotations': total_annotations,
131
+ 'matching_annotations': matching_annotations,
132
+ 'labeled_by': labeled_by,
133
+ 'reviewed_by': reviewed_by
134
+ })
135
+
136
+ return metrics_rows
137
+ except Exception:
138
+ return []
139
+
140
+
141
  def generate_csv(metrics: list, dataset_identifier: str) -> str:
142
  """
143
  Generate CSV content from list of per-sample metrics.
 
196
  if is_multisensor:
197
  sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)
198
 
199
+ # Concurrency control
200
+ parallel_workers = st.slider("Parallel requests", min_value=1, max_value=32, value=8, help="Increase to speed up processing; lower if you hit API limits.")
201
+
202
  if run_button:
203
  st.session_state.csv_content = None
204
  st.session_state.error = None
 
219
  st.info("Checking dataset type...")
220
  try:
221
  target_classes = parse_classes(classes_input)
222
+ target_set = set(target_classes)
223
  metrics = []
224
  # Update loader after dataset type check
225
  if status_ctx is not None:
226
  status_ctx.update(label="Dataset type checked. Processing samples...", state="running")
227
+ progress = st.progress(0)
228
+ total = len(samples_objects)
229
+ done = 0
230
+ with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_workers) as executor:
231
+ futures = [
232
+ executor.submit(
233
+ compute_metrics_for_sample,
234
+ sample,
235
+ api_key,
236
+ target_set,
237
+ is_multisensor,
238
+ sensor_select,
239
+ )
240
+ for sample in samples_objects
241
+ ]
242
+ for future in concurrent.futures.as_completed(futures):
243
+ rows = future.result()
244
+ if rows:
245
+ metrics.extend(rows)
246
+ done += 1
247
+ if total:
248
+ progress.progress(min(done / total, 1.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  if not metrics:
250
  st.session_state.error = "No metrics could be generated for the dataset."
251
  else:
 
268
  data=st.session_state.csv_content,
269
  file_name=filename,
270
  mime="text/csv"
271
+ )