fazeel007's picture
initial commit
7c012de
"""
Knowledge Base Browser - A Gradio Custom Component for RAG applications
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import gradio as gr
from gradio.components.base import Component
from gradio.events import Events
from .retriever import KnowledgeRetriever
from gradio.events import Dependency
class KnowledgeBrowser(Component):
    """
    A custom Gradio component that provides a knowledge base browser interface
    for retrieval-augmented generation use cases.

    The component holds a `KnowledgeRetriever` built from `index_path` and
    exchanges plain dict payloads with the frontend:

    * `preprocess` normalizes an incoming search request
      (query / search_type / max_results / filters).
    * `postprocess` shapes a search outcome for display
      (query / results / search_type / total_count / search_time).
    * `search` runs a query through the retriever and returns a dict in the
      shape `postprocess` accepts.

    It declares the `change`, `submit` and `select` events.
    """

    EVENTS = [
        Events.change,
        Events.submit,
        Events.select,
    ]

    def __init__(
        self,
        query: str = "",
        results: Optional[List[Dict[str, Any]]] = None,
        index_path: str = "./data",
        search_type: str = "semantic",
        max_results: int = 10,
        label: Optional[str] = None,
        every: Optional[float] = None,
        show_label: Optional[bool] = None,
        container: bool = True,
        scale: Optional[int] = None,
        min_width: int = 160,
        visible: bool = True,
        elem_id: Optional[str] = None,
        elem_classes: Optional[List[str] | str] = None,
        render: bool = True,
        **kwargs,
    ):
        """
        Parameters:
            query: Initial search query
            results: Pre-loaded search results (a fresh empty list when None)
            index_path: Path to the document index handed to KnowledgeRetriever
            search_type: Type of search ("semantic", "keyword", "hybrid")
            max_results: Maximum number of results to return
            label: Component label
            every: Timer interval for updates
            show_label: Whether to show the label
            container: Whether to place component in container
            scale: Relative width compared to adjacent components
            min_width: Minimum pixel width
            visible: Whether component is visible
            elem_id: Optional HTML element ID
            elem_classes: Optional HTML element classes
            render: Whether to render component immediately
            **kwargs: Forwarded unchanged to the Gradio `Component` base class
        """
        # Component state is assigned before super().__init__() so the
        # attributes exist if the base constructor calls back into
        # postprocess(). NOTE(review): relies on Gradio internals; confirm.
        self.query = query
        self.results = results or []
        self.search_type = search_type
        self.max_results = max_results
        # Initialize the retriever
        self.retriever = KnowledgeRetriever(index_path)
        super().__init__(
            label=label,
            every=every,
            show_label=show_label,
            container=container,
            scale=scale,
            min_width=min_width,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            **kwargs,
        )

    def preprocess(self, payload: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Convert the frontend payload into the search request used by the backend.

        Parameters:
            payload: Dict sent by the frontend, or None when the frontend
                value is empty. Missing keys fall back to the component's
                configured defaults.
        Returns:
            A dict with the keys "query", "search_type", "max_results" and
            "filters".
        """
        # A None payload used to raise AttributeError on `.get`; treat it as
        # an empty request so every key falls back to its default.
        if payload is None:
            payload = {}
        return {
            "query": payload.get("query", ""),
            "search_type": payload.get("search_type", self.search_type),
            "max_results": payload.get("max_results", self.max_results),
            "filters": payload.get("filters", {}),
        }

    def postprocess(self, value: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Convert a backend value into the payload sent to the frontend.

        Parameters:
            value: Dict produced by an event handler (for example the return
                value of `search`), or None. Missing keys fall back to the
                component's configured state.
        Returns:
            A dict with the keys "query", "results", "search_type",
            "total_count" and "search_time". The shape is the same whether or
            not `value` is None.
        """
        if value is None:
            value = {}
        results = value.get("results", self.results)
        return {
            "query": value.get("query", self.query),
            "results": results,
            "search_type": value.get("search_type", self.search_type),
            # When no explicit count is given, derive it from the results
            # instead of reporting 0 for a non-empty list.
            "total_count": value.get("total_count", len(results or [])),
            "search_time": value.get("search_time", 0),
        }

    def api_info(self) -> Dict[str, Any]:
        """
        Return the JSON schema describing this component's API payloads.

        The schema lists both the request keys read by `preprocess` and the
        response keys produced by `postprocess`.
        """
        # Gradio 4+ consumes api_info() as a JSON schema directly; an
        # {"info": ..., "serialized_info": ...} wrapper has no top-level
        # "type" and so is not a usable schema.
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "search_type": {"type": "string"},
                "max_results": {"type": "number"},
                "filters": {"type": "object"},
                "results": {"type": "array", "items": {"type": "object"}},
                "total_count": {"type": "number"},
                "search_time": {"type": "number"},
            },
        }

    def example_inputs(self) -> Any:
        """
        Return an example request payload for this component.
        """
        return {
            "query": "retrieval augmented generation",
            "search_type": "semantic",
            "max_results": 5,
        }

    def search(
        self,
        query: str,
        search_type: Optional[str] = None,
        max_results: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Run `query` through the knowledge retriever and package the outcome.

        Parameters:
            query: Text to search for.
            search_type: Search strategy. Falsy values fall back to the
                component's configured `search_type`.
            max_results: Result limit passed to the retriever as `k`. Falsy
                values (None or 0) fall back to the configured `max_results`.
        Returns:
            A dict with "query", "results", "search_type", "total_count" and
            "search_time", in the shape accepted by `postprocess`.
        """
        search_type = search_type or self.search_type
        max_results = max_results or self.max_results
        response = self.retriever.search(
            query=query,
            search_type=search_type,
            k=max_results,
        )
        # Assumes the retriever returns a mapping with "documents" and
        # "search_time" keys (the only keys read here). TODO confirm.
        documents = response["documents"]
        return {
            "query": query,
            "results": documents,
            "search_type": search_type,
            "total_count": len(documents),
            "search_time": response["search_time"],
        }

    @property
    def skip_api(self):
        """Keep this component in the generated API (always False)."""
        return False

    # NOTE(review): the imports and `...`-bodied event methods below look like
    # Gradio-generated type stubs for the EVENTS declared above; confirm
    # before editing by hand.
    from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
    from gradio.blocks import Block
    if TYPE_CHECKING:
        from gradio.components import Timer
        from gradio.components.base import Component

    def change(self,
        fn: Callable[..., Any] | None = None,
        inputs: Block | Sequence[Block] | set[Block] | None = None,
        outputs: Block | Sequence[Block] | None = None,
        api_name: str | None | Literal[False] = None,
        scroll_to_output: bool = False,
        show_progress: Literal["full", "minimal", "hidden"] = "full",
        show_progress_on: Component | Sequence[Component] | None = None,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
        every: Timer | float | None = None,
        trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
        js: str | Literal[True] | None = None,
        concurrency_limit: int | None | Literal["default"] = "default",
        concurrency_id: str | None = None,
        show_api: bool = True,
    ) -> Dependency:
        """
        Parameters:
            fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: list of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: list of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
            api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the function's name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
            scroll_to_output: if True, will scroll to output component on completion
            show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
            show_progress_on: Component or list of components to show the progress animation on. If None, will show the progress animation on all of the output components.
            queue: if True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
            batch: if True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: if False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: if False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: a list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another component's .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish.
            every: continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
            trigger_mode: if "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` and `.key_up()` events) would allow a second submission after the pending event is complete.
            js: optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
            concurrency_limit: if set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
            concurrency_id: if set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
            show_api: whether to show this event in the "view API" page of the Gradio app, or in the ".view_api()" method of the Gradio clients. Unlike setting api_name to False, setting show_api to False will still allow downstream apps as well as the Clients to use this event. If fn is None, show_api will automatically be set to False.
        """
        ...

    def submit(self,
        fn: Callable[..., Any] | None = None,
        inputs: Block | Sequence[Block] | set[Block] | None = None,
        outputs: Block | Sequence[Block] | None = None,
        api_name: str | None | Literal[False] = None,
        scroll_to_output: bool = False,
        show_progress: Literal["full", "minimal", "hidden"] = "full",
        show_progress_on: Component | Sequence[Component] | None = None,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
        every: Timer | float | None = None,
        trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
        js: str | Literal[True] | None = None,
        concurrency_limit: int | None | Literal["default"] = "default",
        concurrency_id: str | None = None,
        show_api: bool = True,
    ) -> Dependency:
        """
        Parameters:
            fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: list of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: list of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
            api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the function's name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
            scroll_to_output: if True, will scroll to output component on completion
            show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
            show_progress_on: Component or list of components to show the progress animation on. If None, will show the progress animation on all of the output components.
            queue: if True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
            batch: if True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: if False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: if False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: a list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another component's .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish.
            every: continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
            trigger_mode: if "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` and `.key_up()` events) would allow a second submission after the pending event is complete.
            js: optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
            concurrency_limit: if set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
            concurrency_id: if set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
            show_api: whether to show this event in the "view API" page of the Gradio app, or in the ".view_api()" method of the Gradio clients. Unlike setting api_name to False, setting show_api to False will still allow downstream apps as well as the Clients to use this event. If fn is None, show_api will automatically be set to False.
        """
        ...

    def select(self,
        fn: Callable[..., Any] | None = None,
        inputs: Block | Sequence[Block] | set[Block] | None = None,
        outputs: Block | Sequence[Block] | None = None,
        api_name: str | None | Literal[False] = None,
        scroll_to_output: bool = False,
        show_progress: Literal["full", "minimal", "hidden"] = "full",
        show_progress_on: Component | Sequence[Component] | None = None,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
        every: Timer | float | None = None,
        trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
        js: str | Literal[True] | None = None,
        concurrency_limit: int | None | Literal["default"] = "default",
        concurrency_id: str | None = None,
        show_api: bool = True,
    ) -> Dependency:
        """
        Parameters:
            fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: list of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: list of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
            api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the function's name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
            scroll_to_output: if True, will scroll to output component on completion
            show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
            show_progress_on: Component or list of components to show the progress animation on. If None, will show the progress animation on all of the output components.
            queue: if True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
            batch: if True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: if False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: if False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: a list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another component's .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish.
            every: continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
            trigger_mode: if "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` and `.key_up()` events) would allow a second submission after the pending event is complete.
            js: optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
            concurrency_limit: if set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
            concurrency_id: if set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
            show_api: whether to show this event in the "view API" page of the Gradio app, or in the ".view_api()" method of the Gradio clients. Unlike setting api_name to False, setting show_api to False will still allow downstream apps as well as the Clients to use this event. If fn is None, show_api will automatically be set to False.
        """
        ...
# Public API: a star-import of this module exposes only the component class.
__all__ = [
    "KnowledgeBrowser",
]