File size: 5,415 Bytes
7c012de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Knowledge Base Browser - A Gradio Custom Component for RAG applications
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import gradio as gr
from gradio.components.base import Component
from gradio.events import Events

from .retriever import KnowledgeRetriever


class KnowledgeBrowser(Component):
    """
    A custom Gradio component that provides a knowledge base browser interface
    for retrieval-augmented generation use cases.

    The component keeps a default query, a list of pre-loaded results, a default
    search type and a result limit. It also owns a ``KnowledgeRetriever`` built
    from ``index_path``, to which ``search`` delegates. Data is exchanged with the
    frontend as plain dicts; see ``preprocess`` and ``postprocess`` for the keys
    used in each direction.
    """

    EVENTS = [
        Events.change,
        Events.submit,
        Events.select,
    ]

    def __init__(
        self,
        query: str = "",
        results: Optional[List[Dict[str, Any]]] = None,
        index_path: str = "./data",
        search_type: str = "semantic",
        max_results: int = 10,
        label: Optional[str] = None,
        every: Optional[float] = None,
        show_label: Optional[bool] = None,
        container: bool = True,
        scale: Optional[int] = None,
        min_width: int = 160,
        visible: bool = True,
        elem_id: Optional[str] = None,
        elem_classes: Optional[List[str] | str] = None,
        render: bool = True,
        **kwargs,
    ):
        """
        Parameters:
            query: Initial search query
            results: Pre-loaded search results (defaults to an empty list)
            index_path: Path to document index, passed to ``KnowledgeRetriever``
            search_type: Type of search ("semantic", "keyword", "hybrid")
            max_results: Maximum number of results to return
            label: Component label
            every: Timer interval for updates
            show_label: Whether to show the label
            container: Whether to place component in container
            scale: Relative width compared to adjacent components
            min_width: Minimum pixel width
            visible: Whether component is visible
            elem_id: Optional HTML element ID
            elem_classes: Optional HTML element classes
            render: Whether to render component immediately
            **kwargs: Forwarded unchanged to ``Component.__init__``.
        """
        # These attributes are set before super().__init__() so that
        # postprocess(), which reads them, works if the base initializer calls it.
        # NOTE(review): confirm the base-class call order for the installed Gradio.
        self.query = query
        self.results = results or []
        self.search_type = search_type
        self.max_results = max_results

        # The retriever is created eagerly from index_path at construction time.
        self.retriever = KnowledgeRetriever(index_path)

        super().__init__(
            label=label,
            every=every,
            show_label=show_label,
            container=container,
            scale=scale,
            min_width=min_width,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            **kwargs,
        )

    def preprocess(self, payload: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Convert the frontend payload into the search request used by the backend.

        Parameters:
            payload: Dict sent by the frontend. ``None`` (no input) is treated
                as an empty request instead of raising ``AttributeError``.

        Returns:
            Dict with keys ``query`` (str), ``search_type``, ``max_results`` and
            ``filters`` (dict). Missing keys fall back to the component defaults;
            an explicit ``null`` query or filters value becomes ``""`` / ``{}``.
        """
        payload = payload or {}
        return {
            "query": payload.get("query") or "",
            "search_type": payload.get("search_type", self.search_type),
            "max_results": payload.get("max_results", self.max_results),
            "filters": payload.get("filters") or {},
        }

    def postprocess(self, value: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Convert a backend value (e.g. the dict returned by ``search``) into the
        payload sent to the frontend.

        Parameters:
            value: Dict with any of the keys ``query``, ``results``,
                ``search_type``, ``total_count`` and ``search_time``; may be ``None``.

        Returns:
            A dict that always carries all five keys, so the frontend sees the
            same shape whether or not a value was supplied. Missing entries fall
            back to the component's stored query/results/search type. A missing
            ``total_count`` defaults to the number of results, and a missing
            ``search_time`` defaults to 0.
        """
        value = value or {}

        results = value.get("results")
        if results is None:
            results = self.results

        return {
            "query": value.get("query", self.query),
            "results": results,
            "search_type": value.get("search_type", self.search_type),
            # Previously defaulted to 0 even when results were present.
            "total_count": value.get("total_count", len(results)),
            "search_time": value.get("search_time", 0),
        }

    def api_info(self) -> Dict[str, Any]:
        """
        Returns the API information for this component: a JSON-schema-like
        description of the dict produced by ``postprocess``.
        """
        # NOTE(review): the {"info": ..., "serialized_info": ...} envelope follows an
        # older Gradio serializer convention; confirm it matches what the installed
        # Gradio version expects from Component.api_info().
        return {
            "info": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "results": {"type": "array"},
                    "search_type": {"type": "string"},
                    "total_count": {"type": "number"},
                    "search_time": {"type": "number"},
                },
            },
            "serialized_info": False,
        }

    def example_inputs(self) -> Any:
        """
        Returns an example search request in the shape accepted by ``preprocess``.
        """
        return {
            "query": "retrieval augmented generation",
            "search_type": "semantic",
            "max_results": 5,
        }

    def search(
        self,
        query: str,
        search_type: Optional[str] = None,
        max_results: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Performs a search using the knowledge retriever.

        Parameters:
            query: Text to search for.
            search_type: Search mode to request; any falsy value (``None``, ``""``)
                falls back to the component default.
            max_results: Result limit passed to the retriever as ``k``; any falsy
                value (``None``, ``0``) falls back to the component default.

        Returns:
            Dict in the shape accepted by ``postprocess``: ``query``, ``results``
            (the retriever's ``documents``), ``search_type``, ``total_count`` and
            ``search_time``.
        """
        search_type = search_type or self.search_type
        max_results = max_results or self.max_results

        # The retriever result is read as a mapping with "documents" and
        # "search_time" keys.
        retrieved = self.retriever.search(
            query=query,
            search_type=search_type,
            k=max_results,
        )
        documents = retrieved["documents"]

        return {
            "query": query,
            "results": documents,
            "search_type": search_type,
            "total_count": len(documents),
            "search_time": retrieved["search_time"],
        }

    @property
    def skip_api(self):
        """Always ``False``: this component is included in the generated API."""
        return False


# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["KnowledgeBrowser"]