"""
Knowledge Base Browser - a Gradio custom component for RAG applications.
"""
|
|
|
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import gradio as gr
from gradio.components.base import Component
from gradio.events import Events

from .retriever import KnowledgeRetriever
|
|
|
|
|
class KnowledgeBrowser(Component):
    """
    A custom Gradio component that provides a knowledge base browser interface
    for retrieval-augmented generation use cases.

    The component keeps a default query, result list, search type and result
    limit, and owns a ``KnowledgeRetriever`` built from ``index_path`` that
    ``search`` delegates to.
    """

    EVENTS = [
        Events.change,
        Events.submit,
        Events.select,
    ]

    def __init__(
        self,
        query: str = "",
        results: Optional[List[Dict[str, Any]]] = None,
        index_path: str = "./data",
        search_type: str = "semantic",
        max_results: int = 10,
        label: Optional[str] = None,
        every: Optional[float] = None,
        show_label: Optional[bool] = None,
        container: bool = True,
        scale: Optional[int] = None,
        min_width: int = 160,
        visible: bool = True,
        elem_id: Optional[str] = None,
        elem_classes: Optional[Union[List[str], str]] = None,
        render: bool = True,
        **kwargs,
    ):
        """
        Parameters:
            query: Initial search query
            results: Pre-loaded search results
            index_path: Path to document index
            search_type: Type of search ("semantic", "keyword", "hybrid")
            max_results: Maximum number of results to return
            label: Component label
            every: Timer interval for updates
            show_label: Whether to show the label
            container: Whether to place component in container
            scale: Relative width compared to adjacent components
            min_width: Minimum pixel width
            visible: Whether component is visible
            elem_id: Optional HTML element ID
            elem_classes: Optional HTML element classes
            render: Whether to render component immediately
            **kwargs: Forwarded unchanged to the base ``Component`` initializer
        """
        # Component state is assigned before the base initializer runs,
        # preserving the original ordering.
        self.query = query
        self.results = results or []
        self.index_path = index_path
        self.search_type = search_type
        self.max_results = max_results

        self.retriever = KnowledgeRetriever(index_path)

        super().__init__(
            label=label,
            every=every,
            show_label=show_label,
            container=container,
            scale=scale,
            min_width=min_width,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            **kwargs,
        )

    def preprocess(self, payload: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Preprocesses the component's payload to convert it to a format expected by the backend.

        Parameters:
            payload: Mapping sent by the frontend, or None when no value was sent.
        Returns:
            A dict with "query", "search_type", "max_results" and "filters".
            Missing keys fall back to "", the component's configured search
            type, its configured result limit, and an empty dict respectively.
        """
        # A None payload is treated as an empty one so the defaults apply
        # instead of raising AttributeError on ``None.get``.
        payload = payload or {}

        return {
            "query": payload.get("query", ""),
            "search_type": payload.get("search_type", self.search_type),
            "max_results": payload.get("max_results", self.max_results),
            "filters": payload.get("filters", {}),
        }

    def postprocess(self, value: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Postprocesses the component's value to convert it to a format expected by the frontend.

        Parameters:
            value: Search state to display, or None to fall back to the
                component's stored query and results.
        Returns:
            A dict with "query", "results", "search_type", "total_count" and
            "search_time". The same five keys are returned whether or not
            ``value`` is None, matching the shape declared in ``api_info``.
        """
        if value is None:
            value = {}

        results = value.get("results", self.results)
        # When the caller gives no explicit count, report how many results are
        # actually being sent rather than a hard-coded 0.
        result_count = len(results) if results is not None else 0

        return {
            "query": value.get("query", self.query),
            "results": results,
            "search_type": value.get("search_type", self.search_type),
            "total_count": value.get("total_count", result_count),
            "search_time": value.get("search_time", 0),
        }

    def api_info(self) -> Dict[str, Any]:
        """
        Returns the API information for this component.

        The schema describes the object produced by ``postprocess``.
        """
        return {
            "info": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "results": {"type": "array"},
                    "search_type": {"type": "string"},
                    "total_count": {"type": "number"},
                    "search_time": {"type": "number"},
                },
            },
            "serialized_info": False,
        }

    def example_inputs(self) -> Any:
        """
        Returns example inputs for this component.
        """
        return {
            "query": "retrieval augmented generation",
            "search_type": "semantic",
            "max_results": 5,
        }

    def search(
        self,
        query: str,
        search_type: Optional[str] = None,
        max_results: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Performs a search using the knowledge retriever.

        Parameters:
            query: Text to search for.
            search_type: Search mode; falls back to the component's configured
                search type when None or empty.
            max_results: Result limit passed to the retriever as ``k``; falls
                back to the component's configured limit when None or 0.
        Returns:
            A dict with "query", "results", "search_type", "total_count" and
            "search_time", built from the retriever's "documents" and
            "search_time" entries.
        """
        # NOTE: ``or`` is kept deliberately so any falsy argument (None, "",
        # 0) selects the component default, as callers may rely on that.
        search_type = search_type or self.search_type
        max_results = max_results or self.max_results

        response = self.retriever.search(
            query=query,
            search_type=search_type,
            k=max_results,
        )
        documents = response["documents"]

        return {
            "query": query,
            "results": documents,
            "search_type": search_type,
            "total_count": len(documents),
            "search_time": response["search_time"],
        }

    @property
    def skip_api(self):
        """Always False for this component."""
        return False
|
|
|
|
|
|
|
# Public API of this module.
__all__ = ["KnowledgeBrowser"]