#!/usr/bin/env python3
"""
File: count_by_class.py
Location: 6_Data_metrics/count_by_class.py

Description:
    Streamlit application to count and report metrics per sample for
    specified classes. For each sample, outputs:
      - Sample name
      - Sample URL (including labelset)
      - Number of frames
      - Total number of annotations
      - Number of annotations matching any of the user-specified classes
        (plus one column per requested class id)
      - Labeled by (from label data)
      - Reviewed by (from label data)

Usage:
    streamlit run 6_Data_metrics/count_by_class.py
"""

import streamlit as st
import io
import csv
from segments import SegmentsClient
from datetime import datetime
import sys
import os

from get_labels_from_samples import get_samples as get_samples_objects


def init_session_state():
    """Ensure the session-state keys used by this app exist."""
    if 'csv_content' not in st.session_state:
        st.session_state.csv_content = None
    if 'error' not in st.session_state:
        st.session_state.error = None


def init_client(api_key: str) -> SegmentsClient:
    """Initialize the Segments.ai API client using the provided API key."""
    return SegmentsClient(api_key)


def parse_classes(input_str: str) -> list:
    """
    Parse user input for classes (ranges and comma-separated lists).

    Accepts tokens such as "1,2,5" or "1-3" (inclusive range). Malformed
    tokens are silently skipped. Returns a unique, sorted list of ints.
    """
    classes = []
    for token in input_str.split(','):
        token = token.strip()
        if '-' in token:
            try:
                start, end = map(int, token.split('-'))
            except ValueError:
                # Malformed range (e.g. "1-2-3", "a-b"): skip the token.
                continue
            classes.extend(range(start, end + 1))
        else:
            try:
                classes.append(int(token))
            except ValueError:
                continue
    return sorted(set(classes))


def _count_from_frames(frames, target_set, class_ids):
    """
    Count frames, total annotations, and per-class matches in `frames`.

    Returns a tuple (num_frames, total_annotations, matching_annotations,
    class_counts) where class_counts maps each id in `class_ids` to the
    number of annotations carrying that category_id.
    """
    if not frames:
        return 0, 0, 0, {cid: 0 for cid in class_ids}
    num_frames = len(frames)
    total_annotations = 0
    matching_annotations = 0
    class_counts = {cid: 0 for cid in class_ids}
    for frame in frames:
        anns = getattr(frame, 'annotations', [])
        total_annotations += len(anns)
        if not target_set:
            continue
        for ann in anns:
            cid = getattr(ann, 'category_id', None)
            if cid in target_set:
                matching_annotations += 1
                if cid in class_counts:
                    class_counts[cid] += 1
    return num_frames, total_annotations, matching_annotations, class_counts


def _metric_row(sample, labelset, sensor_name, counts, labeled_by, reviewed_by):
    """Assemble one metrics dict for a (sample, sensor) pair."""
    num_frames, total_annotations, matching_annotations, class_counts = counts
    return {
        'name': getattr(sample, 'name', sample.uuid),
        'uuid': sample.uuid,
        'labelset': labelset,
        'sensor': sensor_name,
        'num_frames': num_frames,
        'total_annotations': total_annotations,
        'matching_annotations': matching_annotations,
        'class_counts': class_counts,
        'labeled_by': labeled_by,
        'reviewed_by': reviewed_by,
    }


def compute_metrics_for_sample(sample, api_key, target_set, class_ids,
                               is_multisensor, sensor_select):
    """
    Fetch the label for a single sample and compute its metrics.

    Returns a list of metric dicts: one per sensor when 'All sensors' is
    selected on a multisensor dataset, otherwise a single-element list.
    API/network errors propagate to the caller, which records them per
    sample (previously they were silently swallowed into an empty list,
    losing the error message).
    """
    client = init_client(api_key)
    label = client.get_label(sample.uuid)
    labelset = getattr(label, 'labelset', '') or ''
    labeled_by = getattr(label, 'created_by', '') or ''
    reviewed_by = getattr(label, 'reviewed_by', '') or ''

    metrics_rows = []
    if is_multisensor:
        sensors = getattr(getattr(label, 'attributes', None), 'sensors', None) or []
        if sensor_select and sensor_select != 'All sensors':
            # Single selected sensor: locate it by name, emit one row, stop.
            for sensor in sensors:
                if getattr(sensor, 'name', None) == sensor_select:
                    frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
                    counts = _count_from_frames(frames, target_set, class_ids)
                    metrics_rows.append(
                        _metric_row(sample, labelset, sensor_select, counts,
                                    labeled_by, reviewed_by))
                    break
        else:
            # All sensors: one row per sensor.
            for sensor in sensors:
                sensor_name = getattr(sensor, 'name', 'Unknown')
                frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
                counts = _count_from_frames(frames, target_set, class_ids)
                metrics_rows.append(
                    _metric_row(sample, labelset, sensor_name, counts,
                                labeled_by, reviewed_by))
    else:
        # Single-sensor dataset: frames hang directly off the label attributes.
        frames = getattr(getattr(label, 'attributes', None), 'frames', [])
        counts = _count_from_frames(frames, target_set, class_ids)
        metrics_rows.append(
            _metric_row(sample, labelset, '', counts, labeled_by, reviewed_by))
    return metrics_rows


def generate_csv(metrics: list, dataset_identifier: str, target_classes: list[int]) -> str:
    """
    Generate CSV content from the list of per-sample metrics.

    Columns: name, sample_url, sensor, num_frames, total_annotations,
    matching_annotations, one class_<id> column per requested class,
    labeled_by, reviewed_by.
    """
    output = io.StringIO()
    writer = csv.writer(output)
    header = [
        'name', 'sample_url', 'sensor', 'num_frames',
        'total_annotations', 'matching_annotations',
    ]
    # Dynamic per-class columns, in the same order used for each row below.
    header.extend(f'class_{cid}' for cid in target_classes)
    header.extend(['labeled_by', 'reviewed_by'])
    writer.writerow(header)
    for m in metrics:
        url = f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}"
        row = [
            m['name'], url, m['sensor'], m['num_frames'],
            m['total_annotations'], m['matching_annotations'],
        ]
        class_counts = m.get('class_counts', {})
        row.extend(class_counts.get(cid, 0) for cid in target_classes)
        row.extend([m['labeled_by'], m['reviewed_by']])
        writer.writerow(row)
    content = output.getvalue()
    output.close()
    return content


# ----------------------
# Streamlit UI
# ----------------------
init_session_state()
st.title("Per-Sample Annotation Counts by Class")

api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []

if api_key and dataset_identifier:
    # Peek at the first sample's label to decide whether the dataset is
    # multisensor (label.attributes.sensors present) and collect sensor names.
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
            label = client.get_label(samples_objects[0].uuid)
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        st.warning(f"Could not inspect dataset sensors: {e}")

if is_multisensor:
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)

if run_button:
    st.session_state.csv_content = None
    st.session_state.error = None
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        # Show loader/status message while checking dataset type and
        # generating the CSV. st.status is not available on older Streamlit
        # versions, so fall back to st.info.
        status_ctx = None
        try:
            status_ctx = st.status("Checking dataset type...", expanded=True)
        except AttributeError:
            st.info("Checking dataset type...")
        try:
            target_classes = parse_classes(classes_input)
            target_set = set(target_classes)
            metrics = []
            total = len(samples_objects)
            if status_ctx is not None:
                status_ctx.update(label=f"Dataset type checked. Processing {total} samples...", state="running")
            progress = st.progress(0)
            done = 0
            failed_samples = []
            # Process samples sequentially; per-sample failures are collected
            # and reported instead of aborting the whole run.
            for sample in samples_objects:
                try:
                    rows = compute_metrics_for_sample(
                        sample,
                        api_key,
                        target_set,
                        target_classes,
                        is_multisensor,
                        sensor_select,
                    )
                    if rows:
                        metrics.extend(rows)
                    else:
                        failed_samples.append(f"Sample {sample.uuid}: No metrics generated")
                except Exception as e:
                    failed_samples.append(f"Sample {sample.uuid}: {str(e)}")
                done += 1
                if total:
                    progress.progress(min(done / total, 1.0))
            if not metrics:
                st.session_state.error = "No metrics could be generated for the dataset."
            else:
                st.session_state.csv_content = generate_csv(metrics, dataset_identifier, target_classes)
                # Report samples (not CSV rows — multisensor samples produce
                # several rows each).
                ok_samples = total - len(failed_samples)
                success_msg = f"CSV generated! Processed {ok_samples} samples"
                if failed_samples:
                    success_msg += f" ({len(failed_samples)} samples failed)"
                if status_ctx is not None:
                    status_ctx.update(label=success_msg, state="complete")
                if failed_samples:
                    st.warning(f"{len(failed_samples)} samples failed processing. First few errors:")
                    for failure in failed_samples[:5]:  # Show first 5 failures
                        st.text(failure)
        except Exception as e:
            st.session_state.error = f"An error occurred: {e}"
            if status_ctx is not None:
                status_ctx.update(label="Error occurred.", state="error")

if st.session_state.error:
    st.error(st.session_state.error)

if st.session_state.csv_content:
    today_str = datetime.now().strftime("%Y%m%d")
    # dataset_identifier contains '/', which is invalid in a filename.
    safe_identifier = dataset_identifier.replace('/', '_')
    filename = f"{today_str}_{safe_identifier}_count-by-class.csv"
    st.download_button(
        "Download CSV",
        data=st.session_state.csv_content,
        file_name=filename,
        mime="text/csv"
    )