# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
import numpy.typing as npt

# Default class vocabulary used by `postprocess_included` when none is given.
_DEFAULT_CLASSES = ("table", "chart", "title", "infographic")


def get_overlaps(
    boxes: npt.NDArray[np.float64],
    other_boxes: npt.NDArray[np.float64],
    normalize: Literal["box_only", "all"] = "box_only",
) -> npt.NDArray[np.float64]:
    """
    Computes how much each box overlaps with each other box.
    Boxes are expected in format (x0, y0, x1, y1).

    The overlap is the intersection area divided by:
    - "box_only": the area of the box (1.0 means the box lies inside the other box),
    - "all": the smaller of the two areas (1.0 means one box contains the other).

    Pairs whose normalizing area is not strictly positive (degenerate boxes) get an
    overlap of 0 instead of NaN.

    Args:
        boxes (np array [4] or [n x 4]): Boxes.
        other_boxes (np array [m x 4]): Other boxes.
        normalize (str, optional): Normalization mode, "box_only" or "all".
            Defaults to "box_only".

    Returns:
        np array [n x m]: Overlaps.

    Raises:
        ValueError: If `normalize` is not "box_only" or "all".
    """
    if normalize not in ("box_only", "all"):
        raise ValueError(f"Invalid normalization: {normalize}")

    if boxes.ndim == 1:
        boxes = boxes[None, :]

    # Coordinates of `boxes` are column vectors [n x 1] ...
    x0, y0, x1, y1 = (boxes[:, k][:, None] for k in range(4))
    areas = (y1 - y0) * (x1 - x0)

    # ... and coordinates of `other_boxes` are row vectors [1 x m], so every
    # binary operation between them broadcasts to the full [n x m] grid.
    x0_other, y0_other, x1_other, y1_other = (
        other_boxes[:, k][None, :] for k in range(4)
    )
    areas_other = (y1_other - y0_other) * (x1_other - x0_other)

    # Intersection (clipped at 0 for disjoint boxes)
    inter_h = np.maximum(0, np.minimum(y1, y1_other) - np.maximum(y0, y0_other))
    inter_w = np.maximum(0, np.minimum(x1, x1_other) - np.maximum(x0, x0_other))
    inter_area = inter_h * inter_w

    if normalize == "box_only":
        # Only consider box included in other box
        denominator = areas
    else:
        # Consider box included in other box and other box included in box.
        # Both operands are already broadcastable to [n x m]; adding another axis
        # here would produce a 3-D result.
        denominator = np.minimum(areas, areas_other)

    # Divide only where the denominator is valid; everything else stays at 0.
    overlaps = np.zeros(inter_area.shape, dtype=np.float64)
    np.divide(inter_area, denominator, out=overlaps, where=denominator > 0)
    return overlaps


def get_distances(
    title_boxes: npt.NDArray[np.float64], other_boxes: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """
    Computes the distances between title and table/chart boxes.
    Distance is computed as the vertical distance plus half the horizontal distance.
    Horizontal distance uses min(boxes center dist, boxes left dist).
    Vertical distance uses min(top_title to bottom_other dists, bottom_title to top_other dists).

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        other_boxes (np array [n_other x 4]): Other boxes.

    Returns:
        np array [n_titles x n_other]: Distances between titles and other boxes.
    """
    x0_title = title_boxes[:, 0][:, None]
    xc_title = ((title_boxes[:, 0] + title_boxes[:, 2]) / 2)[:, None]
    y0_title = title_boxes[:, 1][:, None]
    y1_title = title_boxes[:, 3][:, None]

    x0_other = other_boxes[:, 0][None, :]
    xc_other = ((other_boxes[:, 0] + other_boxes[:, 2]) / 2)[None, :]
    y0_other = other_boxes[:, 1][None, :]
    y1_other = other_boxes[:, 3][None, :]

    x_dists = np.minimum(
        np.abs(xc_title - xc_other),  # Title center to other center
        np.abs(x0_title - x0_other),  # Title left to other left
    )
    y_dists = np.minimum(
        np.abs(y1_title - y0_other),  # Title above other
        np.abs(y0_title - y1_other),  # Title below other
    )

    # Horizontal misalignment is weighted half as much as vertical distance.
    return y_dists + x_dists / 2


def find_titles(
    title_boxes: npt.NDArray[np.float64],
    table_boxes: npt.NDArray[np.float64],
    chart_boxes: npt.NDArray[np.float64],
    max_dist: float = 0.1,
) -> Dict[int, Tuple[str, int]]:
    """
    Associates titles to tables and charts.

    Tables are processed first, then charts, each in index order. Every table/chart
    takes its closest still-unassigned title if that title is closer than `max_dist`.
    A title overlapping a chart by more than 25% of its own area is treated as being
    at distance 0 from that chart and is penalized (x10) for table assignment.

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        table_boxes (np array [n_tables x 4]): Table boxes.
        chart_boxes (np array [n_charts x 4]): Chart boxes.
        max_dist (float, optional): Maximum distance between title and table/chart.
            Defaults to 0.1.

    Returns:
        dict: Dictionary of assigned titles.
            - Keys are the indices of the titles,
            - Values are tuples of:
                - str: Whether the title is assigned to a "chart" or "table"
                - int: index of the assigned table/chart
    """
    n_titles, n_tables, n_charts = len(title_boxes), len(table_boxes), len(chart_boxes)
    if not n_titles or not (n_tables or n_charts):
        return {}

    # Distances to charts; titles sitting inside a chart are at distance 0 from it.
    chart_distances = np.ones((n_titles, 0))
    inside_chart = None
    if n_charts:
        chart_distances = get_distances(title_boxes, chart_boxes)
        chart_overlaps = get_overlaps(title_boxes, chart_boxes, normalize="box_only")
        chart_distances = np.where(chart_overlaps > 0.25, 0, chart_distances)
        inside_chart = chart_overlaps.max(1, keepdims=True) > 0.25

    # Distances to tables
    table_distances = np.ones((n_titles, 0))
    if n_tables:
        table_distances = get_distances(title_boxes, table_boxes)
        if inside_chart is not None:
            # Penalize table titles that are inside charts
            table_distances = np.where(
                inside_chart, table_distances * 10, table_distances
            )

    assigned_titles: Dict[int, Tuple[str, int]] = {}

    # Assign to tables
    for i in range(n_tables):
        best_match = int(np.argmin(table_distances[:, i]))
        if table_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("table", i)
            # A title can only be used once: remove it from further consideration.
            table_distances[best_match] = np.inf
            chart_distances[best_match] = np.inf

    # Assign to charts
    for i in range(n_charts):
        best_match = int(np.argmin(chart_distances[:, i]))
        if chart_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("chart", i)
            chart_distances[best_match] = np.inf

    return assigned_titles


def postprocess_included(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    confs: npt.NDArray[np.float64],
    class_: str = "title",
    classes: Optional[List[str]] = None,
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Post process predictions of one class (titles by default).
    - Remove boxes of that class that are included (> 90% of their area) in a
      table, infographic or chart box (and text box for "title"/"header_footer").

    Inclusion classes that are missing from `classes`, or equal to `class_` itself,
    are ignored.

    Args:
        boxes (numpy.ndarray [N, 4]): Array of bounding boxes.
        labels (numpy.ndarray [N]): Array of labels.
        confs (numpy.ndarray [N]): Array of confidences.
        class_ (str, optional): Class to postprocess. Defaults to "title".
        classes (list, optional): Classes.
            Defaults to ["table", "chart", "title", "infographic"].

    Returns:
        boxes (numpy.ndarray): Array of bounding boxes.
        labels (numpy.ndarray): Array of labels.
        confs (numpy.ndarray): Array of confidences.
        Boxes of other classes come first in their original order, followed by the
        kept boxes of `class_` sorted from least to most confident. If there is no
        box of `class_`, the inputs are returned unchanged.

    Raises:
        ValueError: If `class_` is not in `classes`.
    """
    if classes is None:
        classes = list(_DEFAULT_CLASSES)

    class_idx = classes.index(class_)
    is_target = labels == class_idx

    boxes_to_pp = boxes[is_target]
    confs_to_pp = confs[is_target]
    if len(boxes_to_pp) == 0:
        return boxes, labels, confs

    order = np.argsort(confs_to_pp)  # least to most confident
    boxes_to_pp, confs_to_pp = boxes_to_pp[order], confs_to_pp[order]

    inclusion_classes = ["table", "infographic", "chart"]
    if class_ in ("header_footer", "title"):
        inclusion_classes.append("text")
    # Skip classes the model does not know about (e.g. "text" with the default
    # class list) and never compare the processed class against itself, since a
    # box always fully overlaps itself.
    inclusion_ids = [
        classes.index(c) for c in inclusion_classes if c in classes and c != class_
    ]
    other_boxes = boxes[np.isin(labels, inclusion_ids)]

    # Remove boxes included in other_boxes
    keep = np.ones(len(boxes_to_pp), dtype=bool)
    if len(other_boxes) > 0:
        overlaps = get_overlaps(boxes_to_pp, other_boxes, normalize="box_only")
        keep = ~(overlaps.max(axis=1) > 0.9)

    kept_boxes = boxes_to_pp[keep]
    kept_confs = confs_to_pp[keep]
    # Keep the caller's label dtype instead of silently promoting to float.
    kept_labels = np.full(len(kept_boxes), class_idx, dtype=labels.dtype)

    # Aggregate
    boxes_pp = np.concatenate([boxes[~is_target], kept_boxes])
    confs_pp = np.concatenate([confs[~is_target], kept_confs])
    labels_pp = np.concatenate([labels[~is_target], kept_labels])

    return boxes_pp, labels_pp, confs_pp