#!/usr/bin/env python3
"""
File: count_by_class.py
Location: 6_Data_metrics/count_by_class.py

Description:
    Streamlit application to count and report metrics per sample for specified classes.
    For each sample, outputs:
      - Sample name
      - Sample URL (including labelset)
      - Number of frames
      - Total number of annotations
      - Number of annotations matching any of the user-specified classes
      - Labeled by (from label data)
      - Reviewed by (from label data)

Usage:
    streamlit run 6_Data_metrics/count_by_class.py
"""

import csv
import io
from datetime import datetime

import streamlit as st
from segments import SegmentsClient

from get_labels_from_samples import get_samples as get_samples_objects
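# Note: the sibling helper is assumed to expose get_samples(client, dataset_identifier)
# and return the dataset's sample objects; see get_labels_from_samples.py.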

def init_session_state():
    """Ensure CSV content and error keys exist in Streamlit session state."""
    if 'csv_content' not in st.session_state:
        st.session_state.csv_content = None
    if 'error' not in st.session_state:
        st.session_state.error = None


def init_client(api_key: str) -> SegmentsClient:
    """Initialize the Segments.ai API client using the provided API key."""
    return SegmentsClient(api_key)


def parse_classes(input_str: str) -> list:
    """
    Parse user input for classes (comma-separated values and ranges like "1-3").
    Invalid tokens are skipped. Returns a sorted list of unique ints.
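
    Example:
        >>> parse_classes("1-3, 5, x")
        [1, 2, 3, 5]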
    """
    classes = []
    tokens = input_str.split(',')
    for token in tokens:
        token = token.strip()
        if '-' in token:
            try:
                start, end = map(int, token.split('-'))
                classes.extend(range(start, end + 1))
            except ValueError:
                continue
        else:
            try:
                classes.append(int(token))
            except ValueError:
                continue
    return sorted(set(classes))


def _count_from_frames(frames, target_set, class_ids):
    """Helper to count frames, totals, and per-class counts directly."""
    if not frames:
        return 0, 0, 0, {cid: 0 for cid in class_ids}
    num_frames = len(frames)
    total_annotations = 0
    matching_annotations = 0
    class_counts = {cid: 0 for cid in class_ids}
    for f in frames:
        anns = getattr(f, 'annotations', None) or []
        total_annotations += len(anns)
        if target_set:
            for ann in anns:
                cid = getattr(ann, 'category_id', None)
                if cid in target_set:
                    matching_annotations += 1
                    if cid in class_counts:
                        class_counts[cid] += 1
    return num_frames, total_annotations, matching_annotations, class_counts
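# Worked example: two frames with category_ids [1, 2] and [2], target_set={2},
# class_ids=[2] -> (2, 3, 2, {2: 2}).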


def compute_metrics_for_sample(sample, api_key, target_set, class_ids, is_multisensor, sensor_select):
    """
    Fetch label for a single sample and compute metrics.
    Returns a list of metric dicts (one per sensor if 'All sensors', otherwise one).
    """
    try:
        client = init_client(api_key)
        label = client.get_label(sample.uuid)
        labelset = getattr(label, 'labelset', '') or ''
        labeled_by = getattr(label, 'created_by', '') or ''
        reviewed_by = getattr(label, 'reviewed_by', '') or ''

        metrics_rows = []

        if is_multisensor:
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None) or []
            if sensor_select and sensor_select != 'All sensors':
                # Restrict to the single selected sensor.
                sensors = [s for s in sensors if getattr(s, 'name', None) == sensor_select]
            for sensor in sensors:
                sensor_name = getattr(sensor, 'name', 'Unknown')
                frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
                num_frames, total_annotations, matching_annotations, class_counts = _count_from_frames(frames, target_set, class_ids)
                metrics_rows.append({
                    'name': getattr(sample, 'name', sample.uuid),
                    'uuid': sample.uuid,
                    'labelset': labelset,
                    'sensor': sensor_name,
                    'num_frames': num_frames,
                    'total_annotations': total_annotations,
                    'matching_annotations': matching_annotations,
                    'class_counts': class_counts,
                    'labeled_by': labeled_by,
                    'reviewed_by': reviewed_by
                })
        else:
            # single-sensor dataset
            frames = getattr(getattr(label, 'attributes', None), 'frames', [])
            num_frames, total_annotations, matching_annotations, class_counts = _count_from_frames(frames, target_set, class_ids)
            metrics_rows.append({
                'name': getattr(sample, 'name', sample.uuid),
                'uuid': sample.uuid,
                'labelset': labelset,
                'sensor': '',
                'num_frames': num_frames,
                'total_annotations': total_annotations,
                'matching_annotations': matching_annotations,
                'class_counts': class_counts,
                'labeled_by': labeled_by,
                'reviewed_by': reviewed_by
            })

        return metrics_rows
    except Exception:
        # Re-raise so the caller's per-sample handler can record the actual
        # failure; returning [] here made real errors look like "no metrics".
        raise


def generate_csv(metrics: list, dataset_identifier: str, target_classes: list[int]) -> str:
    """
    Generate CSV content from a list of per-sample (or per-sensor) metric rows.
    Columns: name, sample_url, sensor, num_frames, total_annotations,
             matching_annotations, one class_<id> column per target class,
             labeled_by, reviewed_by
    """
    output = io.StringIO()
    writer = csv.writer(output)
    header = [
        'name', 'sample_url', 'sensor', 'num_frames',
        'total_annotations', 'matching_annotations'
    ]
    # dynamic per-class columns
    header.extend([f'class_{cid}' for cid in target_classes])
    header.extend(['labeled_by', 'reviewed_by'])
    writer.writerow(header)
    for m in metrics:
        url = f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}"
        row = [
            m['name'], url, m['sensor'],
            m['num_frames'], m['total_annotations'],
            m['matching_annotations']
        ]
        # add per-class counts in the same order as header
        class_counts = m.get('class_counts', {})
        row.extend([class_counts.get(cid, 0) for cid in target_classes])
        row.extend([m['labeled_by'], m['reviewed_by']])
        writer.writerow(row)
    content = output.getvalue()
    output.close()
    return content

# ----------------------
# Streamlit UI
# ----------------------

init_session_state()
st.title("Per-Sample Annotation Counts by Class")

api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []

if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
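            # Peek at the first sample's label to detect a multi-sensor
            # dataset and collect sensor names for the selector below.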
            label = client.get_label(samples_objects[0].uuid)
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        st.warning(f"Could not inspect dataset sensors: {e}")

if is_multisensor:
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)


if run_button:
    st.session_state.csv_content = None
    st.session_state.error = None
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        # Show loader/status message while checking dataset type and generating CSV
        status_ctx = None
        try:
            status_ctx = st.status("Checking dataset type...", expanded=True)
        except AttributeError:
            st.info("Checking dataset type...")
        try:
            target_classes = parse_classes(classes_input)
            target_set = set(target_classes)
            metrics = []
            # Update loader after dataset type check
            total = len(samples_objects)
            if status_ctx is not None:
                status_ctx.update(label=f"Dataset type checked. Processing {total} samples...", state="running")
            progress = st.progress(0)
            done = 0
            failed_samples = []
            
            # Process samples sequentially instead of in parallel
            for i, sample in enumerate(samples_objects):
                try:
                    rows = compute_metrics_for_sample(
                        sample,
                        api_key,
                        target_set,
                        target_classes,
                        is_multisensor,
                        sensor_select,
                    )
                    if rows:
                        metrics.extend(rows)
                    else:
                        failed_samples.append(f"Sample {sample.uuid}: No metrics generated")
                except Exception as e:
                    failed_samples.append(f"Sample {sample.uuid}: {str(e)}")
                
                done += 1
                if total:
                    progress.progress(min(done / total, 1.0))
            
            if not metrics:
                st.session_state.error = "No metrics could be generated for the dataset."
            else:
                st.session_state.csv_content = generate_csv(metrics, dataset_identifier, target_classes)
                success_msg = f"CSV generated! Processed {len(metrics)} samples"
                if failed_samples:
                    success_msg += f" ({len(failed_samples)} samples failed)"
                if status_ctx is not None:
                    status_ctx.update(label=success_msg, state="complete")
                else:
                    st.success(success_msg)
            
            if failed_samples:
                st.warning(f"{len(failed_samples)} samples failed processing. First few errors:")
                for failure in failed_samples[:5]:  # Show first 5 failures
                    st.text(failure)
        except Exception as e:
            st.session_state.error = f"An error occurred: {e}"
            if status_ctx is not None:
                status_ctx.update(label="Error occurred.", state="error")

if st.session_state.error:
    st.error(st.session_state.error)

if st.session_state.csv_content:
    today_str = datetime.now().strftime("%Y%m%d")
    safe_identifier = dataset_identifier.replace('/', '_')
    filename = f"{today_str}_{safe_identifier}_count-by-class.csv"
    st.download_button(
        "Download CSV",
        data=st.session_state.csv_content,
        file_name=filename,
        mime="text/csv"
    )